diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py index 7d1df9d2034b4ec927a2ac3a879861df62f78a28..b9f35aa352d5be3a77de693a6f3c1acf7469ac41 100644 --- a/PPOCRLabel/PPOCRLabel.py +++ b/PPOCRLabel/PPOCRLabel.py @@ -152,16 +152,6 @@ class MainWindow(QMainWindow): self.fileListWidget.setIconSize(QSize(25, 25)) filelistLayout.addWidget(self.fileListWidget) - self.AutoRecognition = QToolButton() - self.AutoRecognition.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) - self.AutoRecognition.setIcon(newIcon('Auto')) - autoRecLayout = QHBoxLayout() - autoRecLayout.setContentsMargins(0, 0, 0, 0) - autoRecLayout.addWidget(self.AutoRecognition) - autoRecContainer = QWidget() - autoRecContainer.setLayout(autoRecLayout) - filelistLayout.addWidget(autoRecContainer) - fileListContainer = QWidget() fileListContainer.setLayout(filelistLayout) self.fileListName = getStr('fileList') @@ -172,17 +162,30 @@ class MainWindow(QMainWindow): # ================== Key List ================== if self.kie_mode: - # self.keyList = QListWidget() self.keyList = UniqueLabelQListWidget() - # self.keyList.itemSelectionChanged.connect(self.keyListSelectionChanged) - # self.keyList.itemDoubleClicked.connect(self.editBox) - # self.keyList.itemChanged.connect(self.keyListItemChanged) + + # set key list height + key_list_height = int(QApplication.desktop().height() // 4) + if key_list_height < 50: + key_list_height = 50 + self.keyList.setMaximumHeight(key_list_height) + self.keyListDockName = getStr('keyListTitle') self.keyListDock = QDockWidget(self.keyListDockName, self) self.keyListDock.setWidget(self.keyList) self.keyListDock.setFeatures(QDockWidget.NoDockWidgetFeatures) filelistLayout.addWidget(self.keyListDock) + self.AutoRecognition = QToolButton() + self.AutoRecognition.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) + self.AutoRecognition.setIcon(newIcon('Auto')) + autoRecLayout = QHBoxLayout() + autoRecLayout.setContentsMargins(0, 0, 0, 0) + autoRecLayout.addWidget(self.AutoRecognition) + autoRecContainer = QWidget() + autoRecContainer.setLayout(autoRecLayout) + filelistLayout.addWidget(autoRecContainer) + # ================== Right Area ================== listLayout = QVBoxLayout() listLayout.setContentsMargins(0, 0, 0, 0) @@ -431,8 +434,7 @@ class MainWindow(QMainWindow): # ================== New Actions ================== edit = action(getStr('editLabel'), self.editLabel, - 'Ctrl+E', 'edit', getStr('editLabelDetail'), - enabled=False) + 'Ctrl+E', 'edit', getStr('editLabelDetail'), enabled=False) AutoRec = action(getStr('autoRecognition'), self.autoRecognition, '', 'Auto', getStr('autoRecognition'), enabled=False) @@ -465,11 +467,10 @@ class MainWindow(QMainWindow): 'Ctrl+Z', "undo", getStr("undo"), enabled=False) change_cls = action(getStr("keyChange"), self.change_box_key, - 'Ctrl+B', "edit", getStr("keyChange"), enabled=False) + 'Ctrl+X', "edit", getStr("keyChange"), enabled=False) lock = action(getStr("lockBox"), self.lockSelectedShape, - None, "lock", getStr("lockBoxDetail"), - enabled=False) + None, "lock", getStr("lockBoxDetail"), enabled=False) self.editButton.setDefaultAction(edit) self.newButton.setDefaultAction(create) @@ -534,9 +535,10 @@ class MainWindow(QMainWindow): fileMenuActions=(opendir, open_dataset_dir, saveLabel, resetAll, quit), beginner=(), advanced=(), editMenu=(createpoly, edit, copy, delete, singleRere, None, undo, undoLastPoint, - None, rotateLeft, rotateRight, None, color1, self.drawSquaresOption, lock), + None, rotateLeft, rotateRight, None, color1, self.drawSquaresOption, 
lock, + None, change_cls), beginnerContext=( - create, edit, copy, delete, singleRere, rotateLeft, rotateRight, lock, change_cls), + create, edit, copy, delete, singleRere, rotateLeft, rotateRight, lock, change_cls), advancedContext=(createMode, editMode, edit, copy, delete, shapeLineColor, shapeFillColor), onLoadActive=(create, createMode, editMode), @@ -1105,7 +1107,9 @@ class MainWindow(QMainWindow): shapes = [format_shape(shape) for shape in self.canvas.shapes if shape.line_color != DEFAULT_LOCK_COLOR] # Can add differrent annotation formats here for box in self.result_dic: - trans_dic = {"label": box[1][0], "points": box[0], "difficult": False, "key_cls": "None"} + trans_dic = {"label": box[1][0], "points": box[0], "difficult": False} + if self.kie_mode: + trans_dic.update({"key_cls": "None"}) if trans_dic["label"] == "" and mode == 'Auto': continue shapes.append(trans_dic) @@ -1113,8 +1117,10 @@ class MainWindow(QMainWindow): try: trans_dic = [] for box in shapes: - trans_dic.append({"transcription": box['label'], "points": box['points'], - "difficult": box['difficult'], "key_cls": box['key_cls']}) + trans_dict = {"transcription": box['label'], "points": box['points'], "difficult": box['difficult']} + if self.kie_mode: + trans_dict.update({"key_cls": box['key_cls']}) + trans_dic.append(trans_dict) self.PPlabel[annotationFilePath] = trans_dic if mode == 'Auto': self.Cachelabel[annotationFilePath] = trans_dic @@ -1424,15 +1430,17 @@ class MainWindow(QMainWindow): # box['ratio'] of the shapes saved in lockedShapes contains the ratio of the # four corner coordinates of the shapes to the height and width of the image for box in self.canvas.lockedShapes: + key_cls = None if not self.kie_mode else box['key_cls'] if self.canvas.isInTheSameImage: shapes.append((box['transcription'], [[s[0] * width, s[1] * height] for s in box['ratio']], - DEFAULT_LOCK_COLOR, box['key_cls'], box['difficult'])) + DEFAULT_LOCK_COLOR, key_cls, box['difficult'])) else: shapes.append(('锁定框:待检测', [[s[0] * width, s[1] * height] for s in box['ratio']], - DEFAULT_LOCK_COLOR, box['key_cls'], box['difficult'])) + DEFAULT_LOCK_COLOR, key_cls, box['difficult'])) if imgidx in self.PPlabel.keys(): for box in self.PPlabel[imgidx]: - shapes.append((box['transcription'], box['points'], None, box['key_cls'], box['difficult'])) + key_cls = None if not self.kie_mode else box['key_cls'] + shapes.append((box['transcription'], box['points'], None, key_cls, box['difficult'])) self.loadLabels(shapes) self.canvas.verified = False @@ -1460,6 +1468,7 @@ class MainWindow(QMainWindow): def adjustScale(self, initial=False): value = self.scalers[self.FIT_WINDOW if initial else self.zoomMode]() self.zoomWidget.setValue(int(100 * value)) + self.imageSlider.setValue(self.zoomWidget.value()) # set zoom slider value def scaleFitWindow(self): """Figure out the size of the pixmap in order to fit the main widget.""" @@ -1600,7 +1609,6 @@ class MainWindow(QMainWindow): else: self.keyDialog.labelList.addItems(self.existed_key_cls_set) - def importDirImages(self, dirpath, isDelete=False): if not self.mayContinue() or not dirpath: return @@ -2238,13 +2246,22 @@ class MainWindow(QMainWindow): print('The program will automatically save once after confirming 5 images (default)') def change_box_key(self): + if not self.kie_mode: + return key_text, _ = self.keyDialog.popUp(self.key_previous_text) if key_text is None: return self.key_previous_text = key_text for shape in self.canvas.selectedShapes: shape.key_cls = key_text + if not 
self.keyList.findItemsByLabel(key_text): + item = self.keyList.createItemFromLabel(key_text) + self.keyList.addItem(item) + rgb = self._get_rgb_by_label(key_text, self.kie_mode) + self.keyList.setItemLabel(item, key_text, rgb) + self._update_shape_color(shape) + self.keyDialog.addLabelHistory(key_text) def undoShapeEdit(self): self.canvas.restoreShape() @@ -2288,9 +2305,10 @@ class MainWindow(QMainWindow): shapes = [format_shape(shape) for shape in self.canvas.selectedShapes] trans_dic = [] for box in shapes: - trans_dic.append({"transcription": box['label'], "ratio": box['ratio'], - "difficult": box['difficult'], - "key_cls": "None" if "key_cls" not in box else box["key_cls"]}) + trans_dict = {"transcription": box['label'], "ratio": box['ratio'], "difficult": box['difficult']} + if self.kie_mode: + trans_dict.update({"key_cls": box["key_cls"]}) + trans_dic.append(trans_dict) self.canvas.lockedShapes = trans_dic self.actions.save.setEnabled(True) diff --git a/PPOCRLabel/README.md b/PPOCRLabel/README.md index 9d5eea048350156957b0079e27a0239ae5e482e3..21db1867aa6b6504595096de56b17f01dbf3e4f6 100644 --- a/PPOCRLabel/README.md +++ b/PPOCRLabel/README.md @@ -9,7 +9,7 @@ PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, w ### Recent Update - 2022.02:(by [PeterH0323](https://github.com/peterh0323) ) - - Added KIE mode, for [detection + identification + keyword extraction] labeling. + - Add KIE Mode by using `--kie`, for [detection + identification + keyword extraction] labeling. - 2022.01:(by [PeterH0323](https://github.com/peterh0323) ) - Improve user experience: prompt for the number of files and labels, optimize interaction, and fix bugs such as only use CPU when inference - 2021.11.17: @@ -54,7 +54,10 @@ PPOCRLabel can be started in two ways: whl package and Python script. The whl pa ```bash pip install PPOCRLabel # install -PPOCRLabel # run + +# Select label mode and run +PPOCRLabel # [Normal mode] for [detection + recognition] labeling +PPOCRLabel --kie True # [KIE mode] for [detection + recognition + keyword extraction] labeling ``` > If you getting this error `OSError: [WinError 126] The specified module could not be found` when you install shapely on windows. Please try to download Shapely whl file using http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely. 
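As a minimal illustrative sketch (not part of the patch; the helper name is hypothetical) of what the `--kie` switch changes in the saved annotations: the `key_cls` field is written only in KIE mode, mirroring the `kie_mode`-conditional dicts in the PPOCRLabel.py hunks above.

```python
# Illustrative sketch only: the field names come from the PPOCRLabel.py diff above;
# make_label_entry itself is a hypothetical helper, not code from this patch.
def make_label_entry(label, points, difficult=False, kie_mode=False, key_cls="None"):
    entry = {"transcription": label, "points": points, "difficult": difficult}
    if kie_mode:
        # key_cls is only persisted when the tool was started with --kie
        entry["key_cls"] = key_cls
    return entry

# Normal mode: no key_cls key. KIE mode: key_cls is included.
print(make_label_entry("Total", [[10, 10], [90, 10], [90, 30], [10, 30]]))
print(make_label_entry("Total", [[10, 10], [90, 10], [90, 30], [10, 30]], kie_mode=True))
```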
@@ -67,13 +70,18 @@ PPOCRLabel # run

```bash
pip3 install PPOCRLabel
pip3 install trash-cli
-PPOCRLabel
+
+# Select label mode and run
+PPOCRLabel # [Normal mode] for [detection + recognition] labeling
+PPOCRLabel --kie True # [KIE mode] for [detection + recognition + keyword extraction] labeling
```

#### MacOS
```bash
pip3 install PPOCRLabel
pip3 install opencv-contrib-python-headless==4.2.0.32
+
+# Select label mode and run
PPOCRLabel # [Normal mode] for [detection + recognition] labeling
PPOCRLabel --kie True # [KIE mode] for [detection + recognition + keyword extraction] labeling
```
@@ -90,6 +98,8 @@ pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl

```bash
cd ./PPOCRLabel # Switch to the PPOCRLabel directory
+
+# Select label mode and run
python PPOCRLabel.py # [Normal mode] for [detection + recognition] labeling
python PPOCRLabel.py --kie True # [KIE mode] for [detection + recognition + keyword extraction] labeling
```
@@ -156,6 +166,7 @@ python PPOCRLabel.py --kie True # [KIE mode] for [detection + recognition + keyw
| X | Rotate the box anti-clockwise |
| C | Rotate the box clockwise |
| Ctrl + E | Edit label of the selected box |
+| Ctrl + X | Change key class of the box when `--kie` is enabled |
| Ctrl + R | Re-recognize the selected box |
| Ctrl + C | Copy and paste the selected box |
| Ctrl + Left Mouse Button | Multi select the label box |

diff --git a/PPOCRLabel/README_ch.md b/PPOCRLabel/README_ch.md
index a25871d9310d9c04a2f2fcbb68013938e26aa956..9728686e0232f4cdc387d579b8344b09366beafd 100644
--- a/PPOCRLabel/README_ch.md
+++ b/PPOCRLabel/README_ch.md
@@ -9,7 +9,7 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P

#### 近期更新
- 2022.02:(by [PeterH0323](https://github.com/peterh0323) )
-  - 新增:KIE 功能,用于打【检测+识别+关键字提取】的标签
+  - 新增:使用 `--kie` 进入 KIE 功能,用于打【检测+识别+关键字提取】的标签
- 2022.01:(by [PeterH0323](https://github.com/peterh0323) )
  - 提升用户体验:新增文件与标记数目提示、优化交互、修复gpu使用等问题
- 2021.11.17:
@@ -57,7 +60,10 @@ PPOCRLabel可通过whl包与Python脚本两种方式启动,whl包形式启动

```bash
pip install PPOCRLabel # 安装
-PPOCRLabel --lang ch # 运行
+
+# 选择标签模式来启动
+PPOCRLabel --lang ch # 启动【普通模式】,用于打【检测+识别】场景的标签
+PPOCRLabel --lang ch --kie True # 启动 【KIE 模式】,用于打【检测+识别+关键字提取】场景的标签
```
> 注意:通过whl包安装PPOCRLabel会自动下载 `paddleocr` whl包,其中shapely依赖可能会出现 `[winRrror 126] 找不到指定模块的问题。` 的错误,建议从[这里](https://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely)下载并安装

##### Ubuntu Linux

```bash
pip3 install PPOCRLabel
pip3 install trash-cli
-PPOCRLabel --lang ch
+
+# 选择标签模式来启动
+PPOCRLabel --lang ch # 启动【普通模式】,用于打【检测+识别】场景的标签
+PPOCRLabel --lang ch --kie True # 启动 【KIE 模式】,用于打【检测+识别+关键字提取】场景的标签
```

##### MacOS
```bash
pip3 install PPOCRLabel
pip3 install opencv-contrib-python-headless==4.2.0.32 # 如果下载过慢请添加"-i https://mirror.baidu.com/pypi/simple"
+
+# 选择标签模式来启动
PPOCRLabel --lang ch # 启动【普通模式】,用于打【检测+识别】场景的标签
PPOCRLabel --lang ch --kie True # 启动 【KIE 模式】,用于打【检测+识别+关键字提取】场景的标签
```
@@ -92,6 +100,8 @@ pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl -i https://mirror.baidu.com/pypi/simple
```bash
cd ./PPOCRLabel # 切换到PPOCRLabel目录
+
+# 选择标签模式来启动
python PPOCRLabel.py --lang ch # 启动【普通模式】,用于打【检测+识别】场景的标签
python PPOCRLabel.py --lang ch --kie True # 启动 【KIE 模式】,用于打【检测+识别+关键字提取】场景的标签
```
@@ -137,25 +147,27 @@ python PPOCRLabel.py --lang ch --kie True # 启动 【KIE 模式】,用于打

### 3.1 快捷键

-| 快捷键 | 说明 |
-|------------------|----------------|
-| Ctrl + shift + R | 对当前图片的所有标记重新识别 |
-| W | 新建矩形框 |
-| Q | 新建四点框 |
-| X | 框逆时针旋转 |
-| C | 框顺时针旋转 |
-| Ctrl + E | 编辑所选框标签 |
-| Ctrl + R | 重新识别所选标记 |
-| Ctrl + C | 复制并粘贴选中的标记框 |
-| Ctrl + 鼠标左键 | 多选标记框 |
-| Backspace | 删除所选框 |
-| Ctrl + V | 确认本张图片标记 |
-| Ctrl + Shift + d | 删除本张图片 |
-| D | 下一张图片 |
-| A | 上一张图片 |
-| Ctrl++ | 缩小 |
-| Ctrl-- | 放大 |
-| ↑→↓← | 移动标记框 |
+| 快捷键 | 说明 |
+|------------------|---------------------------------|
+| Ctrl + shift + R | 对当前图片的所有标记重新识别 |
+| W | 新建矩形框 |
+| Q | 新建四点框 |
+| X | 框逆时针旋转 |
+| C | 框顺时针旋转 |
+| Ctrl + E | 编辑所选框标签 |
+| Ctrl + X | `--kie` 模式下,修改 Box 的关键字种类 |
+| Ctrl + R | 重新识别所选标记 |
+| Ctrl + C | 复制并粘贴选中的标记框 |
+| Ctrl + 鼠标左键 | 多选标记框 |
+| Backspace | 删除所选框 |
+| Ctrl + V | 确认本张图片标记 |
+| Ctrl + Shift + d | 删除本张图片 |
+| D | 下一张图片 |
+| A | 上一张图片 |
+| Ctrl++ | 缩小 |
+| Ctrl-- | 放大 |
+| ↑→↓← | 移动标记框 |
+

### 3.2 内置模型

diff --git a/PPOCRLabel/libs/canvas.py b/PPOCRLabel/libs/canvas.py
index 6c1043da43a5049f1c47f58152431259df9fa36a..e6cddf13ede235fa193daf84d4395d77c371049a 100644
--- a/PPOCRLabel/libs/canvas.py
+++ b/PPOCRLabel/libs/canvas.py
@@ -546,7 +546,7 @@ class Canvas(QWidget):
# Give up if both fail.
for shape in shapes:
point = shape[0]
- offset = QPointF(2.0, 2.0)
+ offset = QPointF(5.0, 5.0)
self.calculateOffsets(shape, point)
self.prevPoint = point
if not self.boundedMoveShape(shape, point - offset):

diff --git a/PPOCRLabel/libs/unique_label_qlist_widget.py b/PPOCRLabel/libs/unique_label_qlist_widget.py
index f1eff7a172d3fecf9c18579ccead5f62ba65ecd5..07ae05fe67d8b1a924d04666220e33f664891e83 100644
--- a/PPOCRLabel/libs/unique_label_qlist_widget.py
+++ b/PPOCRLabel/libs/unique_label_qlist_widget.py
@@ -1,6 +1,6 @@
# -*- encoding: utf-8 -*-

-from PyQt5.QtCore import Qt
+from PyQt5.QtCore import Qt, QSize
from PyQt5 import QtWidgets

@@ -40,6 +40,7 @@ class UniqueLabelQListWidget(EscapableQListWidget):
qlabel.setText('<font color="#{:02x}{:02x}{:02x}">●</font> {} '.format(*color, label))
qlabel.setAlignment(Qt.AlignBottom)

- item.setSizeHint(qlabel.sizeHint())
+ # item.setSizeHint(qlabel.sizeHint())
+ item.setSizeHint(QSize(25, 25))

self.setItemWidget(item, qlabel)

diff --git a/README_ch.md b/README_ch.md
index e6b8b5772e2f2069c24d200fffe121468f6dd4ae..3788f9f0d4003a0a8aa636cd1dd6148936598411 100755
--- a/README_ch.md
+++ b/README_ch.md
@@ -32,7 +32,7 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
- PP-OCR系列高质量预训练模型,准确的识别效果
- 超轻量PP-OCRv2系列:检测(3.1M)+ 方向分类器(1.4M)+ 识别(8.5M)= 13.0M
- 超轻量PP-OCR mobile移动端系列:检测(3.0M)+方向分类器(1.4M)+ 识别(5.0M)= 9.4M
-  - 通用PPOCR server系列:检测(47.1M)+方向分类器(1.4M)+ 识别(94.9M)= 143.4M
+  - 通用PP-OCR server系列:检测(47.1M)+方向分类器(1.4M)+ 识别(94.9M)= 143.4M
- 支持中英文数字组合识别、竖排文本识别、长文本识别
- 支持多语言识别:韩语、日语、德语、法语等约80种语言
- PP-Structure文档结构化系统

diff --git a/benchmark/run_benchmark_det.sh b/benchmark/run_benchmark_det.sh
index 818aa7e3e1fb342174a0cf5be4d45af0b0205a39..9f5b46cde1da58ccc3fbb128e52aa1cfe4f3dd53 100644
--- a/benchmark/run_benchmark_det.sh
+++ b/benchmark/run_benchmark_det.sh
@@ -1,5 +1,4 @@
#!/usr/bin/env bash
-set -xe
# 运行示例:CUDA_VISIBLE_DEVICES=0 bash run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode}
# 参数说明
function _set_params(){
@@ -34,11 +33,13 @@ function _train(){
train_cmd="python tools/train.py "${train_cmd}""
;;
mp)
+        rm -rf ./mylog
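+        # 先清空 ./mylog,避免残留上一次多卡训练的日志(下面的 --log_dir 会重新生成该目录)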
train_cmd="python -m paddle.distributed.launch --log_dir=./mylog --gpus=$CUDA_VISIBLE_DEVICES tools/train.py ${train_cmd}" ;; *) echo "choose run_mode(sp or mp)"; exit 1; esac # 以下不用修改 + echo ${train_cmd} timeout 15m ${train_cmd} > ${log_file} 2>&1 if [ $? -ne 0 ];then echo -e "${model_name}, FAIL" diff --git a/configs/det/det_mv3_pse.yml b/configs/det/det_mv3_pse.yml index 61ac24727acbd4f0b1eea15af08c0f9e71ce95a3..f80180ce7c1604cfda42ede36930e1bd9fdb8e21 100644 --- a/configs/det/det_mv3_pse.yml +++ b/configs/det/det_mv3_pse.yml @@ -56,7 +56,7 @@ PostProcess: thresh: 0 box_thresh: 0.85 min_area: 16 - box_type: box # 'box' or 'poly' + box_type: quad # 'quad' or 'poly' scale: 1 Metric: diff --git a/configs/det/det_r50_vd_dcn_fce_ctw.yml b/configs/det/det_r50_vd_dcn_fce_ctw.yml new file mode 100755 index 0000000000000000000000000000000000000000..a9f7c4143d4e9380c819f8cbc39d69f0149111b2 --- /dev/null +++ b/configs/det/det_r50_vd_dcn_fce_ctw.yml @@ -0,0 +1,139 @@ +Global: + use_gpu: true + epoch_num: 1500 + log_smooth_window: 20 + print_batch_step: 20 + save_model_dir: ./output/det_r50_dcn_fce_ctw/ + save_epoch_step: 100 + # evaluation is run every 835 iterations + eval_batch_step: [0, 835] + cal_metric_during_train: False + pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_fce/predicts_fce.txt + + +Architecture: + model_type: det + algorithm: FCE + Transform: + Backbone: + name: ResNet + layers: 50 + dcn_stage: [False, True, True, True] + out_indices: [1,2,3] + Neck: + name: FCEFPN + out_channels: 256 + has_extra_convs: False + extra_stage: 0 + Head: + name: FCEHead + fourier_degree: 5 +Loss: + name: FCELoss + fourier_degree: 5 + num_sample: 50 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.0001 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: FCEPostProcess + scales: [8, 16, 32] + alpha: 1.0 + beta: 1.0 + fourier_degree: 5 + box_type: 'poly' + +Metric: + name: DetFCEMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ctw1500/imgs/ + label_file_list: + - ./train_data/ctw1500/imgs/training.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + ignore_orientation: True + - DetLabelEncode: # Class handling label + - ColorJitter: + brightness: 0.142 + saturation: 0.5 + contrast: 0.5 + - RandomScaling: + - RandomCropFlip: + crop_ratio: 0.5 + - RandomCropPolyInstances: + crop_ratio: 0.8 + min_side_ratio: 0.3 + - RandomRotatePolyInstances: + rotate_ratio: 0.5 + max_angle: 30 + pad_with_fixed_color: False + - SquareResizePad: + target_size: 800 + pad_ratio: 0.6 + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - FCENetTargets: + fourier_degree: 5 + - NormalizeImage: + scale: 1./255. 
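+          # rescale pixels by 1/255 first, then normalize with mean/std (HWC order)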
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'p3_maps', 'p4_maps', 'p5_maps'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 6 + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ctw1500/imgs/ + label_file_list: + - ./train_data/ctw1500/imgs/test.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + ignore_orientation: True + - DetLabelEncode: # Class handling label + - DetResizeForTest: + limit_type: 'min' + limit_side_len: 736 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - Pad: + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 \ No newline at end of file diff --git a/configs/det/det_r50_vd_pse.yml b/configs/det/det_r50_vd_pse.yml index 4629210747d3b61344cc47b11dcff01e6b738586..8e77506c410af5397a04f73674b414cb28a87c4d 100644 --- a/configs/det/det_r50_vd_pse.yml +++ b/configs/det/det_r50_vd_pse.yml @@ -55,7 +55,7 @@ PostProcess: thresh: 0 box_thresh: 0.85 min_area: 16 - box_type: box # 'box' or 'poly' + box_type: quad # 'quad' or 'poly' scale: 1 Metric: diff --git a/configs/rec/rec_efficientb3_fpn_pren.yml b/configs/rec/rec_efficientb3_fpn_pren.yml new file mode 100644 index 0000000000000000000000000000000000000000..0fac6a7a8abcc5c1e4301b1040e6e98df74abed9 --- /dev/null +++ b/configs/rec/rec_efficientb3_fpn_pren.yml @@ -0,0 +1,92 @@ +Global: + use_gpu: True + epoch_num: 8 + log_smooth_window: 20 + print_batch_step: 5 + save_model_dir: ./output/rec/pren_new + save_epoch_step: 3 + # evaluation is run every 2000 iterations after the 4000th iteration + eval_batch_step: [4000, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: + max_text_length: &max_text_length 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_pren.txt + +Optimizer: + name: Adadelta + lr: + name: Piecewise + decay_epochs: [2, 5, 7] + values: [0.5, 0.1, 0.01, 0.001] + +Architecture: + model_type: rec + algorithm: PREN + in_channels: 3 + Backbone: + name: EfficientNetb3_PREN + Neck: + name: PRENFPN + n_r: 5 + d_model: 384 + max_len: *max_text_length + dropout: 0.1 + Head: + name: PRENHead + +Loss: + name: PRENLoss + +PostProcess: + name: PRENLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/training/ + transforms: + - DecodeImage: + img_mode: BGR + channel_first: False + - PRENLabelEncode: + - RecAug: + - PRENResizeImg: + image_shape: [64, 256] # h,w + - KeepKeys: + keep_keys: ['image', 'label'] + loader: + shuffle: True + batch_size_per_card: 128 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/validation/ + transforms: + - DecodeImage: + img_mode: BGR + channel_first: False + - PRENLabelEncode: + - PRENResizeImg: + image_shape: [64, 256] # h,w + - KeepKeys: + keep_keys: ['image', 'label'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 64 + num_workers: 8 diff --git a/deploy/android_demo/README.md b/deploy/android_demo/README.md index 
1642323ebb347dfcaf0f14f0fc00c139065f53cf..ba615fba904eed9686f645b51bf7f9821b555653 100644
--- a/deploy/android_demo/README.md
+++ b/deploy/android_demo/README.md
@@ -1,19 +1,118 @@
-# 如何快速测试
-### 1. 安装最新版本的Android Studio
-可以从 https://developer.android.com/studio 下载。本Demo使用是4.0版本Android Studio编写。
+- [Android Demo](#android-demo)
+  - [1. 简介](#1-简介)
+  - [2. 近期更新](#2-近期更新)
+  - [3. 快速使用](#3-快速使用)
+    - [3.1 环境准备](#31-环境准备)
+    - [3.2 导入项目](#32-导入项目)
+    - [3.3 运行demo](#33-运行demo)
+    - [3.4 运行模式](#34-运行模式)
+    - [3.5 设置](#35-设置)
+  - [4 更多支持](#4-更多支持)

-### 2. 按照NDK 20 以上版本
-Demo测试的时候使用的是NDK 20b版本,20版本以上均可以支持编译成功。
+# Android Demo

-如果您是初学者,可以用以下方式安装和测试NDK编译环境。
-点击 File -> New ->New Project, 新建 "Native C++" project
+## 1. 简介
+此为PaddleOCR的Android Demo,目前支持文本检测,文本方向分类器和文本识别模型的使用。使用 [PaddleLite v2.10](https://github.com/PaddlePaddle/Paddle-Lite/tree/release/v2.10) 进行开发。
+
+## 2. 近期更新
+* 2022.02.27
+  * 预测库更新到PaddleLite v2.10
+  * 支持6种运行模式:
+    * 检测+分类+识别
+    * 检测+识别
+    * 分类+识别
+    * 检测
+    * 识别
+    * 分类
+
+## 3. 快速使用
+
+### 3.1 环境准备
+1. 在本地环境安装好 Android Studio 工具,详细安装方法请见[Android Studio 官网](https://developer.android.com/studio)。
+2. 准备一部 Android 手机,并开启 USB 调试模式。开启方法: `手机设置 -> 查找开发者选项 -> 打开开发者选项和 USB 调试模式`
+
+**注意**:如果您的 Android Studio 尚未配置 NDK ,请根据 Android Studio 用户指南中的[安装及配置 NDK 和 CMake ](https://developer.android.com/studio/projects/install-ndk)内容,预先配置好 NDK 。您可以选择最新的 NDK 版本,或者使用与 Paddle Lite 预测库版本一致的 NDK。
+
+### 3.2 导入项目

-### 3. 导入项目
点击 File->New->Import Project..., 然后跟着Android Studio的引导导入
+导入完成后呈现如下界面
+![](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/imgs/import_demo.jpg)
+
+### 3.3 运行demo
+将手机连接上电脑后,点击Android Studio工具栏中的运行按钮即可运行demo。在此过程中,手机会弹出"允许从 USB 安装软件权限"的弹窗,点击允许即可。
+
+软件安装到手机上后会在手机主屏最后一页看到如下app
+<!-- 截图:手机主屏上的 app 图标 -->
+ +点击app图标即可启动app,启动后app主页如下 + +
+<!-- 截图:app 主页 -->
+
+app主页中有四个按钮,一个下拉列表和一个菜单按钮,它们的功能分别为
+
+* 运行模型:按照已选择的模式,运行对应的模型组合
+* 拍照识别:唤起手机相机拍照并获取拍照的图像,拍照完成后需要点击运行模型进行识别
+* 选取图片:唤起手机相册选择图像,选择完成后需要点击运行模型进行识别
+* 清空绘图:清空当前显示图像上绘制的文本框,以便进行下一次识别(每次识别使用的图像都是当前显示的图像)
+* 下拉列表:进行运行模式的选择,目前包含6种运行模式,默认模式为**检测+分类+识别**,详细说明见下一节。
+* 菜单按钮:点击后会进入菜单界面,进行模型和内置图像的相关设置
+
+点击运行模型后,会按照所选择的模式运行对应的模型,**检测+分类+识别**模式下运行的模型结果如下所示:
+
+<!-- 截图:检测+分类+识别 模式的运行结果 -->
+
+模型运行完成后,模型和运行状态显示区`STATUS`字段显示了当前模型的运行状态,这里显示为`run model successed`表明模型运行成功。
+
+模型的运行结果显示在运行结果显示区,显示格式为
+```text
+序号:Det:(x1,y1)(x2,y2)(x3,y3)(x4,y4) Rec: 识别文本,识别置信度 Cls:分类类别,分类得分
+```
+
+### 3.4 运行模式
+
+PaddleOCR demo共提供了6种运行模式,如下图
+<!-- 截图:运行模式下拉列表 -->
+
+每种模式的运行结果如下表所示
+
+| 检测+分类+识别 | 检测+识别 | 分类+识别 |
+|----------------|-----------|-----------|
+| (结果截图) | (结果截图) | (结果截图) |
+
+| 检测 | 识别 | 分类 |
+|------|------|------|
+| (结果截图) | (结果截图) | (结果截图) |
+
+### 3.5 设置
+
+设置界面如下
+<!-- 截图:设置界面 -->
-# 获得更多支持 -前往[端计算模型生成平台EasyEdge](https://ai.baidu.com/easyedge/app/open_source_demo?referrerUrl=paddlelite),获得更多开发支持: +在设置界面可以进行如下几项设定: +1. 普通设置 + * Enable custom settings: 选中状态下才能更改设置 + * Model Path: 所运行的模型地址,使用默认值就好 + * Label Path: 识别模型的字典 + * Image Path: 进行识别的内置图像名 +2. 模型运行态设置,此项设置更改后返回主界面时,会自动重新加载模型 + * CPU Thread Num: 模型运行使用的CPU核心数量 + * CPU Power Mode: 模型运行模式,大小核设定 +3. 输入设置 + * det long size: DB模型预处理时图像的长边长度,超过此长度resize到该值,短边进行等比例缩放,小于此长度不进行处理。 +4. 输出设置 + * Score Threshold: DB模型后处理box的阈值,低于此阈值的box进行过滤,不显示。 -- Demo APP:可使用手机扫码安装,方便手机端快速体验文字识别 -- SDK:模型被封装为适配不同芯片硬件和操作系统SDK,包括完善的接口,方便进行二次开发 +## 4 更多支持 +1. 实时识别,更新预测库可参考 https://github.com/PaddlePaddle/Paddle-Lite-Demo/tree/develop/ocr/android/app/cxx/ppocr_demo +2. 更多Paddle-Lite相关问题可前往[Paddle-Lite](https://github.com/PaddlePaddle/Paddle-Lite) ,获得更多开发支持 diff --git a/deploy/android_demo/app/build.gradle b/deploy/android_demo/app/build.gradle index cff8d5f61142ef52cfb838923c67a7de335aeb5c..2607f32eca5b6b5612960127cdfc09c78989f3b1 100644 --- a/deploy/android_demo/app/build.gradle +++ b/deploy/android_demo/app/build.gradle @@ -8,8 +8,8 @@ android { applicationId "com.baidu.paddle.lite.demo.ocr" minSdkVersion 23 targetSdkVersion 29 - versionCode 1 - versionName "1.0" + versionCode 2 + versionName "2.0" testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" externalNativeBuild { cmake { @@ -17,11 +17,6 @@ android { arguments '-DANDROID_PLATFORM=android-23', '-DANDROID_STL=c++_shared' ,"-DANDROID_ARM_NEON=TRUE" } } - ndk { - // abiFilters "arm64-v8a", "armeabi-v7a" - abiFilters "arm64-v8a", "armeabi-v7a" - ldLibs "jnigraphics" - } } buildTypes { release { @@ -48,7 +43,7 @@ dependencies { def archives = [ [ - 'src' : 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/paddle_lite_libs_v2_9_0.tar.gz', + 'src' : 'https://paddleocr.bj.bcebos.com/libs/paddle_lite_libs_v2_10.tar.gz', 'dest': 'PaddleLite' ], [ @@ -56,7 +51,7 @@ def archives = [ 'dest': 'OpenCV' ], [ - 'src' : 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ocr_v2_for_cpu.tar.gz', + 'src' : 'https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_PP-OCRv2.tar.gz', 'dest' : 'src/main/assets/models' ], [ diff --git a/deploy/android_demo/app/src/main/AndroidManifest.xml b/deploy/android_demo/app/src/main/AndroidManifest.xml index 54482b1dcc9de66021d0109e5683302c8445ba6a..133f35703c39ed92b91a453b98d964acc8545c63 100644 --- a/deploy/android_demo/app/src/main/AndroidManifest.xml +++ b/deploy/android_demo/app/src/main/AndroidManifest.xml @@ -14,7 +14,6 @@ android:roundIcon="@mipmap/ic_launcher_round" android:supportsRtl="true" android:theme="@style/AppTheme"> - diff --git a/deploy/android_demo/app/src/main/assets/images/0.jpg b/deploy/android_demo/app/src/main/assets/images/det_0.jpg similarity index 100% rename from deploy/android_demo/app/src/main/assets/images/0.jpg rename to deploy/android_demo/app/src/main/assets/images/det_0.jpg diff --git a/deploy/android_demo/app/src/main/assets/images/180.jpg b/deploy/android_demo/app/src/main/assets/images/det_180.jpg similarity index 100% rename from deploy/android_demo/app/src/main/assets/images/180.jpg rename to deploy/android_demo/app/src/main/assets/images/det_180.jpg diff --git a/deploy/android_demo/app/src/main/assets/images/270.jpg b/deploy/android_demo/app/src/main/assets/images/det_270.jpg similarity index 100% rename from deploy/android_demo/app/src/main/assets/images/270.jpg rename to deploy/android_demo/app/src/main/assets/images/det_270.jpg diff --git a/deploy/android_demo/app/src/main/assets/images/90.jpg 
b/deploy/android_demo/app/src/main/assets/images/det_90.jpg similarity index 100% rename from deploy/android_demo/app/src/main/assets/images/90.jpg rename to deploy/android_demo/app/src/main/assets/images/det_90.jpg diff --git a/deploy/android_demo/app/src/main/assets/images/rec_0.jpg b/deploy/android_demo/app/src/main/assets/images/rec_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2c34cd33eac5766a072fde041fa6c9b1d612f1db Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/rec_0.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/rec_0_180.jpg b/deploy/android_demo/app/src/main/assets/images/rec_0_180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..02bc3b9279ad6c8e91fd3f169f1613c603094c44 Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/rec_0_180.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/rec_1.jpg b/deploy/android_demo/app/src/main/assets/images/rec_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..22031ba501c7337bf99e5f0c5a687196d7d27f63 Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/rec_1.jpg differ diff --git a/deploy/android_demo/app/src/main/assets/images/rec_1_180.jpg b/deploy/android_demo/app/src/main/assets/images/rec_1_180.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d74553016e520e667523c1d37c899f7cd1d4f3bc Binary files /dev/null and b/deploy/android_demo/app/src/main/assets/images/rec_1_180.jpg differ diff --git a/deploy/android_demo/app/src/main/cpp/native.cpp b/deploy/android_demo/app/src/main/cpp/native.cpp index 963c5246d5b7b50720f92705d288526ae2cc6a73..ced932556f09244d1e9e962e7b75461203a7cc3a 100644 --- a/deploy/android_demo/app/src/main/cpp/native.cpp +++ b/deploy/android_demo/app/src/main/cpp/native.cpp @@ -13,7 +13,7 @@ static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode); extern "C" JNIEXPORT jlong JNICALL Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init( JNIEnv *env, jobject thiz, jstring j_det_model_path, - jstring j_rec_model_path, jstring j_cls_model_path, jint j_thread_num, + jstring j_rec_model_path, jstring j_cls_model_path, jint j_use_opencl, jint j_thread_num, jstring j_cpu_mode) { std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path); std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path); @@ -21,6 +21,7 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init( int thread_num = j_thread_num; std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode); ppredictor::OCR_Config conf; + conf.use_opencl = j_use_opencl; conf.thread_num = thread_num; conf.mode = str_to_cpu_mode(cpu_mode); ppredictor::OCR_PPredictor *orc_predictor = @@ -57,32 +58,31 @@ str_to_cpu_mode(const std::string &cpu_mode) { extern "C" JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward( - JNIEnv *env, jobject thiz, jlong java_pointer, jfloatArray buf, - jfloatArray ddims, jobject original_image) { + JNIEnv *env, jobject thiz, jlong java_pointer, jobject original_image,jint j_max_size_len, jint j_run_det, jint j_run_cls, jint j_run_rec) { LOGI("begin to run native forward"); if (java_pointer == 0) { LOGE("JAVA pointer is NULL"); return cpp_array_to_jfloatarray(env, nullptr, 0); } + cv::Mat origin = bitmap_to_cv_mat(env, original_image); if (origin.size == 0) { LOGE("origin bitmap cannot convert to CV Mat"); return 
cpp_array_to_jfloatarray(env, nullptr, 0); } + + int max_size_len = j_max_size_len; + int run_det = j_run_det; + int run_cls = j_run_cls; + int run_rec = j_run_rec; + ppredictor::OCR_PPredictor *ppredictor = (ppredictor::OCR_PPredictor *)java_pointer; - std::vector dims_float_arr = jfloatarray_to_float_vector(env, ddims); std::vector dims_arr; - dims_arr.resize(dims_float_arr.size()); - std::copy(dims_float_arr.cbegin(), dims_float_arr.cend(), dims_arr.begin()); - - // 这里值有点大,就不调用jfloatarray_to_float_vector了 - int64_t buf_len = (int64_t)env->GetArrayLength(buf); - jfloat *buf_data = env->GetFloatArrayElements(buf, JNI_FALSE); - float *data = (jfloat *)buf_data; std::vector results = - ppredictor->infer_ocr(dims_arr, data, buf_len, NET_OCR, origin); + ppredictor->infer_ocr(origin, max_size_len, run_det, run_cls, run_rec); LOGI("infer_ocr finished with boxes %ld", results.size()); + // 这里将std::vector 序列化成 // float数组,传输到java层再反序列化 std::vector float_arr; @@ -90,13 +90,18 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward( float_arr.push_back(r.points.size()); float_arr.push_back(r.word_index.size()); float_arr.push_back(r.score); + // add det point for (const std::vector &point : r.points) { float_arr.push_back(point.at(0)); float_arr.push_back(point.at(1)); } + // add rec word idx for (int index : r.word_index) { float_arr.push_back(index); } + // add cls result + float_arr.push_back(r.cls_label); + float_arr.push_back(r.cls_score); } return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size()); } diff --git a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp index c68456e17acbafbcbceffddd9f521e0c9bfcf774..1bd989c9da9f644fe485d6e00aae6d1793114cd0 100644 --- a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp +++ b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.cpp @@ -17,15 +17,15 @@ int OCR_PPredictor::init(const std::string &det_model_content, const std::string &rec_model_content, const std::string &cls_model_content) { _det_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR, _config.mode}); + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR, _config.mode}); _det_predictor->init_nb(det_model_content); _rec_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); _rec_predictor->init_nb(rec_model_content); _cls_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); _cls_predictor->init_nb(cls_model_content); return RETURN_OK; } @@ -34,15 +34,16 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path, const std::string &rec_model_path, const std::string &cls_model_path) { _det_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR, _config.mode}); + new PPredictor{_config.use_opencl, _config.thread_num, NET_OCR, _config.mode}); _det_predictor->init_from_file(det_model_path); + _rec_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + new PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); _rec_predictor->init_from_file(rec_model_path); _cls_predictor = std::unique_ptr( - new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode}); + new 
PPredictor{_config.use_opencl,_config.thread_num, NET_OCR_INTERNAL, _config.mode}); _cls_predictor->init_from_file(cls_model_path); return RETURN_OK; } @@ -77,90 +78,173 @@ visual_img(const std::vector>> &filter_boxes, } std::vector -OCR_PPredictor::infer_ocr(const std::vector &dims, - const float *input_data, int input_len, int net_flag, - cv::Mat &origin) { +OCR_PPredictor::infer_ocr(cv::Mat &origin,int max_size_len, int run_det, int run_cls, int run_rec) { + LOGI("ocr cpp start *****************"); + LOGI("ocr cpp det: %d, cls: %d, rec: %d", run_det, run_cls, run_rec); + std::vector ocr_results; + if(run_det){ + infer_det(origin, max_size_len, ocr_results); + } + if(run_rec){ + if(ocr_results.size()==0){ + OCRPredictResult res; + ocr_results.emplace_back(std::move(res)); + } + for(int i = 0; i < ocr_results.size();i++) { + infer_rec(origin, run_cls, ocr_results[i]); + } + }else if(run_cls){ + ClsPredictResult cls_res = infer_cls(origin); + OCRPredictResult res; + res.cls_score = cls_res.cls_score; + res.cls_label = cls_res.cls_label; + ocr_results.push_back(res); + } + + LOGI("ocr cpp end *****************"); + return ocr_results; +} + +cv::Mat DetResizeImg(const cv::Mat img, int max_size_len, + std::vector &ratio_hw) { + int w = img.cols; + int h = img.rows; + + float ratio = 1.f; + int max_wh = w >= h ? w : h; + if (max_wh > max_size_len) { + if (h > w) { + ratio = static_cast(max_size_len) / static_cast(h); + } else { + ratio = static_cast(max_size_len) / static_cast(w); + } + } + + int resize_h = static_cast(float(h) * ratio); + int resize_w = static_cast(float(w) * ratio); + if (resize_h % 32 == 0) + resize_h = resize_h; + else if (resize_h / 32 < 1 + 1e-5) + resize_h = 32; + else + resize_h = (resize_h / 32 - 1) * 32; + + if (resize_w % 32 == 0) + resize_w = resize_w; + else if (resize_w / 32 < 1 + 1e-5) + resize_w = 32; + else + resize_w = (resize_w / 32 - 1) * 32; + + cv::Mat resize_img; + cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); + + ratio_hw.push_back(static_cast(resize_h) / static_cast(h)); + ratio_hw.push_back(static_cast(resize_w) / static_cast(w)); + return resize_img; +} + +void OCR_PPredictor::infer_det(cv::Mat &origin, int max_size_len, std::vector &ocr_results) { + std::vector mean = {0.485f, 0.456f, 0.406f}; + std::vector scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; + PredictorInput input = _det_predictor->get_first_input(); - input.set_dims(dims); - input.set_data(input_data, input_len); + + std::vector ratio_hw; + cv::Mat input_image = DetResizeImg(origin, max_size_len, ratio_hw); + input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f); + const float *dimg = reinterpret_cast(input_image.data); + int input_size = input_image.rows * input_image.cols; + + input.set_dims({1, 3, input_image.rows, input_image.cols}); + + neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean, + scale); + LOGI("ocr cpp det shape %d,%d", input_image.rows,input_image.cols); std::vector results = _det_predictor->infer(); PredictorOutput &res = results.at(0); std::vector>> filtered_box = calc_filtered_boxes( - res.get_float_data(), res.get_size(), (int)dims[2], (int)dims[3], origin); - LOGI("Filter_box size %ld", filtered_box.size()); - return infer_rec(filtered_box, origin); + res.get_float_data(), res.get_size(), input_image.rows, input_image.cols, origin); + LOGI("ocr cpp det Filter_box size %ld", filtered_box.size()); + + for(int i = 0;i OCR_PPredictor::infer_rec( - const std::vector>> &boxes, - const cv::Mat &origin_img) { +void 
OCR_PPredictor::infer_rec(const cv::Mat &origin_img, int run_cls, OCRPredictResult& ocr_result) { std::vector mean = {0.5f, 0.5f, 0.5f}; std::vector scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; std::vector dims = {1, 3, 0, 0}; - std::vector ocr_results; PredictorInput input = _rec_predictor->get_first_input(); - for (auto bp = boxes.crbegin(); bp != boxes.crend(); ++bp) { - const std::vector> &box = *bp; - cv::Mat crop_img = get_rotate_crop_image(origin_img, box); - crop_img = infer_cls(crop_img); - float wh_ratio = float(crop_img.cols) / float(crop_img.rows); - cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio); - input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f); - const float *dimg = reinterpret_cast(input_image.data); - int input_size = input_image.rows * input_image.cols; + const std::vector> &box = ocr_result.points; + cv::Mat crop_img; + if(box.size()>0){ + crop_img = get_rotate_crop_image(origin_img, box); + } + else{ + crop_img = origin_img; + } - dims[2] = input_image.rows; - dims[3] = input_image.cols; - input.set_dims(dims); + if(run_cls){ + ClsPredictResult cls_res = infer_cls(crop_img); + crop_img = cls_res.img; + ocr_result.cls_score = cls_res.cls_score; + ocr_result.cls_label = cls_res.cls_label; + } - neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean, - scale); - std::vector results = _rec_predictor->infer(); - const float *predict_batch = results.at(0).get_float_data(); - const std::vector predict_shape = results.at(0).get_shape(); + float wh_ratio = float(crop_img.cols) / float(crop_img.rows); + cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio); + input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f); + const float *dimg = reinterpret_cast(input_image.data); + int input_size = input_image.rows * input_image.cols; - OCRPredictResult res; + dims[2] = input_image.rows; + dims[3] = input_image.cols; + input.set_dims(dims); - // ctc decode - int argmax_idx; - int last_index = 0; - float score = 0.f; - int count = 0; - float max_value = 0.0f; - - for (int n = 0; n < predict_shape[1]; n++) { - argmax_idx = int(argmax(&predict_batch[n * predict_shape[2]], - &predict_batch[(n + 1) * predict_shape[2]])); - max_value = - float(*std::max_element(&predict_batch[n * predict_shape[2]], - &predict_batch[(n + 1) * predict_shape[2]])); - if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) { - score += max_value; - count += 1; - res.word_index.push_back(argmax_idx); - } - last_index = argmax_idx; - } - score /= count; - if (res.word_index.empty()) { - continue; + neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean, + scale); + + std::vector results = _rec_predictor->infer(); + const float *predict_batch = results.at(0).get_float_data(); + const std::vector predict_shape = results.at(0).get_shape(); + + // ctc decode + int argmax_idx; + int last_index = 0; + float score = 0.f; + int count = 0; + float max_value = 0.0f; + + for (int n = 0; n < predict_shape[1]; n++) { + argmax_idx = int(argmax(&predict_batch[n * predict_shape[2]], + &predict_batch[(n + 1) * predict_shape[2]])); + max_value = + float(*std::max_element(&predict_batch[n * predict_shape[2]], + &predict_batch[(n + 1) * predict_shape[2]])); + if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) { + score += max_value; + count += 1; + ocr_result.word_index.push_back(argmax_idx); } - res.score = score; - res.points = box; - ocr_results.emplace_back(std::move(res)); + last_index = argmax_idx; } - LOGI("ocr_results finished %lu", ocr_results.size()); - 
return ocr_results; + score /= count; + ocr_result.score = score; + LOGI("ocr cpp rec word size %ld", count); } -cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { +ClsPredictResult OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { std::vector mean = {0.5f, 0.5f, 0.5f}; std::vector scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; std::vector dims = {1, 3, 0, 0}; - std::vector ocr_results; PredictorInput input = _cls_predictor->get_first_input(); @@ -182,7 +266,7 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { float score = 0; int label = 0; for (int64_t i = 0; i < results.at(0).get_size(); i++) { - LOGI("output scores [%f]", scores[i]); + LOGI("ocr cpp cls output scores [%f]", scores[i]); if (scores[i] > score) { score = scores[i]; label = i; @@ -193,7 +277,12 @@ cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) { if (label % 2 == 1 && score > thresh) { cv::rotate(srcimg, srcimg, 1); } - return srcimg; + ClsPredictResult res; + res.cls_label = label; + res.cls_score = score; + res.img = srcimg; + LOGI("ocr cpp cls word cls %ld, %f", label, score); + return res; } std::vector>> diff --git a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h index 588f25cb0662b6cc5c8677549e7eb3527d74fa9b..f0bff93f1fd621ed2ece62cd8e656a429a77803b 100644 --- a/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h +++ b/deploy/android_demo/app/src/main/cpp/ocr_ppredictor.h @@ -15,7 +15,8 @@ namespace ppredictor { * Config */ struct OCR_Config { - int thread_num = 4; // Thread num + int use_opencl = 0; + int thread_num = 4; // Thread num paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode }; @@ -27,8 +28,15 @@ struct OCRPredictResult { std::vector word_index; std::vector> points; float score; + float cls_score; + int cls_label=-1; }; +struct ClsPredictResult { + float cls_score; + int cls_label=-1; + cv::Mat img; +}; /** * OCR there are 2 models * 1. 
First model(det),select polygones to show where are the texts @@ -62,8 +70,7 @@ public: * @return */ virtual std::vector - infer_ocr(const std::vector &dims, const float *input_data, - int input_len, int net_flag, cv::Mat &origin); + infer_ocr(cv::Mat &origin, int max_size_len, int run_det, int run_cls, int run_rec); virtual NET_TYPE get_net_flag() const; @@ -80,25 +87,26 @@ private: calc_filtered_boxes(const float *pred, int pred_size, int output_height, int output_width, const cv::Mat &origin); + void + infer_det(cv::Mat &origin, int max_side_len, std::vector& ocr_results); /** - * infer for second model + * infer for rec model * * @param boxes * @param origin * @return */ - std::vector - infer_rec(const std::vector>> &boxes, - const cv::Mat &origin); + void + infer_rec(const cv::Mat &origin, int run_cls, OCRPredictResult& ocr_result); - /** + /** * infer for cls model * * @param boxes * @param origin * @return */ - cv::Mat infer_cls(const cv::Mat &origin, float thresh = 0.9); + ClsPredictResult infer_cls(const cv::Mat &origin, float thresh = 0.9); /** * Postprocess or sencod model to extract text diff --git a/deploy/android_demo/app/src/main/cpp/ppredictor.cpp b/deploy/android_demo/app/src/main/cpp/ppredictor.cpp index dcbc76911fca43468768910718b1f1ee7e81bb13..a40fe5e1b289d1c7646ded302f3faba179b1ac2e 100644 --- a/deploy/android_demo/app/src/main/cpp/ppredictor.cpp +++ b/deploy/android_demo/app/src/main/cpp/ppredictor.cpp @@ -2,9 +2,9 @@ #include "common.h" namespace ppredictor { -PPredictor::PPredictor(int thread_num, int net_flag, +PPredictor::PPredictor(int use_opencl, int thread_num, int net_flag, paddle::lite_api::PowerMode mode) - : _thread_num(thread_num), _net_flag(net_flag), _mode(mode) {} + : _use_opencl(use_opencl), _thread_num(thread_num), _net_flag(net_flag), _mode(mode) {} int PPredictor::init_nb(const std::string &model_content) { paddle::lite_api::MobileConfig config; @@ -19,10 +19,40 @@ int PPredictor::init_from_file(const std::string &model_content) { } template int PPredictor::_init(ConfigT &config) { + bool is_opencl_backend_valid = paddle::lite_api::IsOpenCLBackendValid(/*check_fp16_valid = false*/); + if (is_opencl_backend_valid) { + if (_use_opencl != 0) { + // Make sure you have write permission of the binary path. + // We strongly recommend each model has a unique binary name. 
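+      // Compiled OpenCL kernels are cached to bin_path/bin_name on the first run
+      // and reloaded on later runs, which skips kernel recompilation.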
+ const std::string bin_path = "/data/local/tmp/"; + const std::string bin_name = "lite_opencl_kernel.bin"; + config.set_opencl_binary_path_name(bin_path, bin_name); + + // opencl tune option + // CL_TUNE_NONE: 0 + // CL_TUNE_RAPID: 1 + // CL_TUNE_NORMAL: 2 + // CL_TUNE_EXHAUSTIVE: 3 + const std::string tuned_path = "/data/local/tmp/"; + const std::string tuned_name = "lite_opencl_tuned.bin"; + config.set_opencl_tune(paddle::lite_api::CL_TUNE_NORMAL, tuned_path, tuned_name); + + // opencl precision option + // CL_PRECISION_AUTO: 0, first fp16 if valid, default + // CL_PRECISION_FP32: 1, force fp32 + // CL_PRECISION_FP16: 2, force fp16 + config.set_opencl_precision(paddle::lite_api::CL_PRECISION_FP32); + LOGI("ocr cpp device: running on gpu."); + } + } else { + LOGI("ocr cpp device: running on cpu."); + // you can give backup cpu nb model instead + // config.set_model_from_file(cpu_nb_model_dir); + } config.set_threads(_thread_num); config.set_power_mode(_mode); _predictor = paddle::lite_api::CreatePaddlePredictor(config); - LOGI("paddle instance created"); + LOGI("ocr cpp paddle instance created"); return RETURN_OK; } @@ -43,18 +73,18 @@ std::vector PPredictor::get_inputs(int num) { PredictorInput PPredictor::get_first_input() { return get_input(0); } std::vector PPredictor::infer() { - LOGI("infer Run start %d", _net_flag); + LOGI("ocr cpp infer Run start %d", _net_flag); std::vector results; if (!_is_input_get) { return results; } _predictor->Run(); - LOGI("infer Run end"); + LOGI("ocr cpp infer Run end"); for (int i = 0; i < _predictor->GetOutputNames().size(); i++) { std::unique_ptr output_tensor = _predictor->GetOutput(i); - LOGI("output tensor[%d] size %ld", i, product(output_tensor->shape())); + LOGI("ocr cpp output tensor[%d] size %ld", i, product(output_tensor->shape())); PredictorOutput result{std::move(output_tensor), i, _net_flag}; results.emplace_back(std::move(result)); } diff --git a/deploy/android_demo/app/src/main/cpp/ppredictor.h b/deploy/android_demo/app/src/main/cpp/ppredictor.h index 836861aac9aa5bd8eb009a9e4c8138b651beeace..40250764f8ac84b53a7aa2f8696a2e4ab0909f0b 100644 --- a/deploy/android_demo/app/src/main/cpp/ppredictor.h +++ b/deploy/android_demo/app/src/main/cpp/ppredictor.h @@ -22,7 +22,7 @@ public: class PPredictor : public PPredictor_Interface { public: PPredictor( - int thread_num, int net_flag = 0, + int use_opencl, int thread_num, int net_flag = 0, paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH); virtual ~PPredictor() {} @@ -54,6 +54,7 @@ protected: template int _init(ConfigT &config); private: + int _use_opencl; int _thread_num; paddle::lite_api::PowerMode _mode; std::shared_ptr _predictor; diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java index b4ea34e2a38f91f3ecb1001c6bff3b71496b8f91..f932718dcab99e59e63ccf341c35de9c547926cb 100644 --- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java +++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MainActivity.java @@ -13,6 +13,7 @@ import android.graphics.BitmapFactory; import android.graphics.drawable.BitmapDrawable; import android.media.ExifInterface; import android.content.res.AssetManager; +import android.media.FaceDetector; import android.net.Uri; import android.os.Bundle; import android.os.Environment; @@ -27,7 +28,9 @@ import android.view.Menu; import android.view.MenuInflater; import 
android.view.MenuItem; import android.view.View; +import android.widget.CheckBox; import android.widget.ImageView; +import android.widget.Spinner; import android.widget.TextView; import android.widget.Toast; @@ -68,23 +71,24 @@ public class MainActivity extends AppCompatActivity { protected ImageView ivInputImage; protected TextView tvOutputResult; protected TextView tvInferenceTime; + protected CheckBox cbOpencl; + protected Spinner spRunMode; - // Model settings of object detection + // Model settings of ocr protected String modelPath = ""; protected String labelPath = ""; protected String imagePath = ""; protected int cpuThreadNum = 1; protected String cpuPowerMode = ""; - protected String inputColorFormat = ""; - protected long[] inputShape = new long[]{}; - protected float[] inputMean = new float[]{}; - protected float[] inputStd = new float[]{}; + protected int detLongSize = 960; protected float scoreThreshold = 0.1f; private String currentPhotoPath; - private AssetManager assetManager =null; + private AssetManager assetManager = null; protected Predictor predictor = new Predictor(); + private Bitmap cur_predict_image = null; + @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); @@ -98,10 +102,12 @@ public class MainActivity extends AppCompatActivity { // Setup the UI components tvInputSetting = findViewById(R.id.tv_input_setting); + cbOpencl = findViewById(R.id.cb_opencl); tvStatus = findViewById(R.id.tv_model_img_status); ivInputImage = findViewById(R.id.iv_input_image); tvInferenceTime = findViewById(R.id.tv_inference_time); tvOutputResult = findViewById(R.id.tv_output_result); + spRunMode = findViewById(R.id.sp_run_mode); tvInputSetting.setMovementMethod(ScrollingMovementMethod.getInstance()); tvOutputResult.setMovementMethod(ScrollingMovementMethod.getInstance()); @@ -111,26 +117,26 @@ public class MainActivity extends AppCompatActivity { public void handleMessage(Message msg) { switch (msg.what) { case RESPONSE_LOAD_MODEL_SUCCESSED: - if(pbLoadModel!=null && pbLoadModel.isShowing()){ + if (pbLoadModel != null && pbLoadModel.isShowing()) { pbLoadModel.dismiss(); } onLoadModelSuccessed(); break; case RESPONSE_LOAD_MODEL_FAILED: - if(pbLoadModel!=null && pbLoadModel.isShowing()){ + if (pbLoadModel != null && pbLoadModel.isShowing()) { pbLoadModel.dismiss(); } Toast.makeText(MainActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show(); onLoadModelFailed(); break; case RESPONSE_RUN_MODEL_SUCCESSED: - if(pbRunModel!=null && pbRunModel.isShowing()){ + if (pbRunModel != null && pbRunModel.isShowing()) { pbRunModel.dismiss(); } onRunModelSuccessed(); break; case RESPONSE_RUN_MODEL_FAILED: - if(pbRunModel!=null && pbRunModel.isShowing()){ + if (pbRunModel != null && pbRunModel.isShowing()) { pbRunModel.dismiss(); } Toast.makeText(MainActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show(); @@ -175,71 +181,47 @@ public class MainActivity extends AppCompatActivity { super.onResume(); SharedPreferences sharedPreferences = PreferenceManager.getDefaultSharedPreferences(this); boolean settingsChanged = false; + boolean model_settingsChanged = false; String model_path = sharedPreferences.getString(getString(R.string.MODEL_PATH_KEY), getString(R.string.MODEL_PATH_DEFAULT)); String label_path = sharedPreferences.getString(getString(R.string.LABEL_PATH_KEY), getString(R.string.LABEL_PATH_DEFAULT)); String image_path = sharedPreferences.getString(getString(R.string.IMAGE_PATH_KEY), getString(R.string.IMAGE_PATH_DEFAULT)); - 
settingsChanged |= !model_path.equalsIgnoreCase(modelPath); + model_settingsChanged |= !model_path.equalsIgnoreCase(modelPath); settingsChanged |= !label_path.equalsIgnoreCase(labelPath); settingsChanged |= !image_path.equalsIgnoreCase(imagePath); int cpu_thread_num = Integer.parseInt(sharedPreferences.getString(getString(R.string.CPU_THREAD_NUM_KEY), getString(R.string.CPU_THREAD_NUM_DEFAULT))); - settingsChanged |= cpu_thread_num != cpuThreadNum; + model_settingsChanged |= cpu_thread_num != cpuThreadNum; String cpu_power_mode = sharedPreferences.getString(getString(R.string.CPU_POWER_MODE_KEY), getString(R.string.CPU_POWER_MODE_DEFAULT)); - settingsChanged |= !cpu_power_mode.equalsIgnoreCase(cpuPowerMode); - String input_color_format = - sharedPreferences.getString(getString(R.string.INPUT_COLOR_FORMAT_KEY), - getString(R.string.INPUT_COLOR_FORMAT_DEFAULT)); - settingsChanged |= !input_color_format.equalsIgnoreCase(inputColorFormat); - long[] input_shape = - Utils.parseLongsFromString(sharedPreferences.getString(getString(R.string.INPUT_SHAPE_KEY), - getString(R.string.INPUT_SHAPE_DEFAULT)), ","); - float[] input_mean = - Utils.parseFloatsFromString(sharedPreferences.getString(getString(R.string.INPUT_MEAN_KEY), - getString(R.string.INPUT_MEAN_DEFAULT)), ","); - float[] input_std = - Utils.parseFloatsFromString(sharedPreferences.getString(getString(R.string.INPUT_STD_KEY) - , getString(R.string.INPUT_STD_DEFAULT)), ","); - settingsChanged |= input_shape.length != inputShape.length; - settingsChanged |= input_mean.length != inputMean.length; - settingsChanged |= input_std.length != inputStd.length; - if (!settingsChanged) { - for (int i = 0; i < input_shape.length; i++) { - settingsChanged |= input_shape[i] != inputShape[i]; - } - for (int i = 0; i < input_mean.length; i++) { - settingsChanged |= input_mean[i] != inputMean[i]; - } - for (int i = 0; i < input_std.length; i++) { - settingsChanged |= input_std[i] != inputStd[i]; - } - } + model_settingsChanged |= !cpu_power_mode.equalsIgnoreCase(cpuPowerMode); + + int det_long_size = Integer.parseInt(sharedPreferences.getString(getString(R.string.DET_LONG_SIZE_KEY), + getString(R.string.DET_LONG_SIZE_DEFAULT))); + settingsChanged |= det_long_size != detLongSize; float score_threshold = Float.parseFloat(sharedPreferences.getString(getString(R.string.SCORE_THRESHOLD_KEY), getString(R.string.SCORE_THRESHOLD_DEFAULT))); settingsChanged |= scoreThreshold != score_threshold; if (settingsChanged) { - modelPath = model_path; labelPath = label_path; imagePath = image_path; + detLongSize = det_long_size; + scoreThreshold = score_threshold; + set_img(); + } + if (model_settingsChanged) { + modelPath = model_path; cpuThreadNum = cpu_thread_num; cpuPowerMode = cpu_power_mode; - inputColorFormat = input_color_format; - inputShape = input_shape; - inputMean = input_mean; - inputStd = input_std; - scoreThreshold = score_threshold; // Update UI - tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\n" + "CPU" + - " Thread Num: " + Integer.toString(cpuThreadNum) + "\n" + "CPU Power Mode: " + cpuPowerMode); + tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\nOPENCL: " + cbOpencl.isChecked() + "\nCPU Thread Num: " + cpuThreadNum + "\nCPU Power Mode: " + cpuPowerMode); tvInputSetting.scrollTo(0, 0); // Reload model if configure has been changed -// loadModel(); - set_img(); + loadModel(); } } @@ -254,20 +236,28 @@ public class MainActivity extends AppCompatActivity { } public boolean 
@@ -254,20 +236,28 @@ public class MainActivity extends AppCompatActivity {
     }
 
     public boolean onLoadModel() {
-        return predictor.init(MainActivity.this, modelPath, labelPath, cpuThreadNum,
+        if (predictor.isLoaded()) {
+            predictor.releaseModel();
+        }
+        return predictor.init(MainActivity.this, modelPath, labelPath, cbOpencl.isChecked() ? 1 : 0, cpuThreadNum,
                 cpuPowerMode,
-                inputColorFormat,
-                inputShape, inputMean,
-                inputStd, scoreThreshold);
+                detLongSize, scoreThreshold);
     }
 
     public boolean onRunModel() {
-        return predictor.isLoaded() && predictor.runModel();
+        String run_mode = spRunMode.getSelectedItem().toString();
+        int run_det = run_mode.contains("检测") ? 1 : 0;
+        int run_cls = run_mode.contains("分类") ? 1 : 0;
+        int run_rec = run_mode.contains("识别") ? 1 : 0;
+        return predictor.isLoaded() && predictor.runModel(run_det, run_cls, run_rec);
     }
 
     public void onLoadModelSuccessed() {
         // Load test image from path and run model
+        tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\nOPENCL: " + cbOpencl.isChecked() + "\nCPU Thread Num: " + cpuThreadNum + "\nCPU Power Mode: " + cpuPowerMode);
+        tvInputSetting.scrollTo(0, 0);
         tvStatus.setText("STATUS: load model successed");
+
     }
 
     public void onLoadModelFailed() {
@@ -290,20 +280,13 @@ public class MainActivity extends AppCompatActivity {
         tvStatus.setText("STATUS: run model failed");
     }
 
-    public void onImageChanged(Bitmap image) {
-        // Rerun model if users pick test image from gallery or camera
-        if (image != null && predictor.isLoaded()) {
-            predictor.setInputImage(image);
-            runModel();
-        }
-    }
-
     public void set_img() {
         // Load test image from path and run model
         try {
-            assetManager= getAssets();
-            InputStream in=assetManager.open(imagePath);
-            Bitmap bmp=BitmapFactory.decodeStream(in);
+            assetManager = getAssets();
+            InputStream in = assetManager.open(imagePath);
+            Bitmap bmp = BitmapFactory.decodeStream(in);
+            cur_predict_image = bmp;
             ivInputImage.setImageBitmap(bmp);
         } catch (IOException e) {
             Toast.makeText(MainActivity.this, "Load image failed!", Toast.LENGTH_SHORT).show();
@@ -430,7 +413,7 @@ public class MainActivity extends AppCompatActivity {
             Cursor cursor = managedQuery(uri, proj, null, null, null);
             cursor.moveToFirst();
             if (image != null) {
-//                onImageChanged(image);
+                cur_predict_image = image;
                 ivInputImage.setImageBitmap(image);
             }
         } catch (IOException e) {
@@ -451,7 +434,7 @@ public class MainActivity extends AppCompatActivity {
                 Bitmap image = BitmapFactory.decodeFile(currentPhotoPath);
                 image = Utils.rotateBitmap(image, orientation);
                 if (image != null) {
-//                    onImageChanged(image);
+                    cur_predict_image = image;
                     ivInputImage.setImageBitmap(image);
                 }
             } else {
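Both pickers above now cache the chosen bitmap in cur_predict_image rather than re-running inference immediately (the removed onImageChanged()). That cached copy is what the new reset button restores once a prediction has overwritten the ImageView with the visualized boxes. A minimal sketch of the pattern; onImagePicked() is a hypothetical helper, not a method in the diff:

    private Bitmap cur_predict_image;                    // last clean image the user picked

    private void onImagePicked(Bitmap image) {           // hypothetical, for illustration
        cur_predict_image = image;                       // keep an un-annotated copy
        ivInputImage.setImageBitmap(image);              // running the model repaints this view
    }

    public void btn_reset_img_click(View view) {
        ivInputImage.setImageBitmap(cur_predict_image);  // drop the drawn results
    }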
tvStatus.setText("STATUS: run model ...... "); predictor.setInputImage(image); runModel(); } } + public void btn_choice_img_click(View view) { if (requestAllPermissions()) { openGallery(); @@ -506,4 +489,32 @@ public class MainActivity extends AppCompatActivity { worker.quit(); super.onDestroy(); } + + public int get_run_mode() { + String run_mode = spRunMode.getSelectedItem().toString(); + int mode; + switch (run_mode) { + case "检测+分类+识别": + mode = 1; + break; + case "检测+识别": + mode = 2; + break; + case "识别+分类": + mode = 3; + break; + case "检测": + mode = 4; + break; + case "识别": + mode = 5; + break; + case "分类": + mode = 6; + break; + default: + mode = 1; + } + return mode; + } } diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MiniActivity.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MiniActivity.java deleted file mode 100644 index 38c98294717e5446b4a3d245e188ecadc56ee7b5..0000000000000000000000000000000000000000 --- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MiniActivity.java +++ /dev/null @@ -1,157 +0,0 @@ -package com.baidu.paddle.lite.demo.ocr; - -import android.graphics.Bitmap; -import android.graphics.BitmapFactory; -import android.os.Build; -import android.os.Bundle; -import android.os.Handler; -import android.os.HandlerThread; -import android.os.Message; -import android.util.Log; -import android.view.View; -import android.widget.Button; -import android.widget.ImageView; -import android.widget.TextView; -import android.widget.Toast; - -import androidx.appcompat.app.AppCompatActivity; - -import java.io.IOException; -import java.io.InputStream; - -public class MiniActivity extends AppCompatActivity { - - - public static final int REQUEST_LOAD_MODEL = 0; - public static final int REQUEST_RUN_MODEL = 1; - public static final int REQUEST_UNLOAD_MODEL = 2; - public static final int RESPONSE_LOAD_MODEL_SUCCESSED = 0; - public static final int RESPONSE_LOAD_MODEL_FAILED = 1; - public static final int RESPONSE_RUN_MODEL_SUCCESSED = 2; - public static final int RESPONSE_RUN_MODEL_FAILED = 3; - - private static final String TAG = "MiniActivity"; - - protected Handler receiver = null; // Receive messages from worker thread - protected Handler sender = null; // Send command to worker thread - protected HandlerThread worker = null; // Worker thread to load&run model - protected volatile Predictor predictor = null; - - private String assetModelDirPath = "models/ocr_v2_for_cpu"; - private String assetlabelFilePath = "labels/ppocr_keys_v1.txt"; - - private Button button; - private ImageView imageView; // image result - private TextView textView; // text result - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - setContentView(R.layout.activity_mini); - - Log.i(TAG, "SHOW in Logcat"); - - // Prepare the worker thread for mode loading and inference - worker = new HandlerThread("Predictor Worker"); - worker.start(); - sender = new Handler(worker.getLooper()) { - public void handleMessage(Message msg) { - switch (msg.what) { - case REQUEST_LOAD_MODEL: - // Load model and reload test image - if (!onLoadModel()) { - runOnUiThread(new Runnable() { - @Override - public void run() { - Toast.makeText(MiniActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show(); - } - }); - } - break; - case REQUEST_RUN_MODEL: - // Run model if model is loaded - final boolean isSuccessed = onRunModel(); - runOnUiThread(new Runnable() { - @Override - public void run() { - 
diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MiniActivity.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MiniActivity.java
deleted file mode 100644
index 38c98294717e5446b4a3d245e188ecadc56ee7b5..0000000000000000000000000000000000000000
--- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/MiniActivity.java
+++ /dev/null
@@ -1,157 +0,0 @@
-package com.baidu.paddle.lite.demo.ocr;
-
-import android.graphics.Bitmap;
-import android.graphics.BitmapFactory;
-import android.os.Build;
-import android.os.Bundle;
-import android.os.Handler;
-import android.os.HandlerThread;
-import android.os.Message;
-import android.util.Log;
-import android.view.View;
-import android.widget.Button;
-import android.widget.ImageView;
-import android.widget.TextView;
-import android.widget.Toast;
-
-import androidx.appcompat.app.AppCompatActivity;
-
-import java.io.IOException;
-import java.io.InputStream;
-
-public class MiniActivity extends AppCompatActivity {
-
-
-    public static final int REQUEST_LOAD_MODEL = 0;
-    public static final int REQUEST_RUN_MODEL = 1;
-    public static final int REQUEST_UNLOAD_MODEL = 2;
-    public static final int RESPONSE_LOAD_MODEL_SUCCESSED = 0;
-    public static final int RESPONSE_LOAD_MODEL_FAILED = 1;
-    public static final int RESPONSE_RUN_MODEL_SUCCESSED = 2;
-    public static final int RESPONSE_RUN_MODEL_FAILED = 3;
-
-    private static final String TAG = "MiniActivity";
-
-    protected Handler receiver = null; // Receive messages from worker thread
-    protected Handler sender = null; // Send command to worker thread
-    protected HandlerThread worker = null; // Worker thread to load&run model
-    protected volatile Predictor predictor = null;
-
-    private String assetModelDirPath = "models/ocr_v2_for_cpu";
-    private String assetlabelFilePath = "labels/ppocr_keys_v1.txt";
-
-    private Button button;
-    private ImageView imageView; // image result
-    private TextView textView; // text result
-
-    @Override
-    protected void onCreate(Bundle savedInstanceState) {
-        super.onCreate(savedInstanceState);
-        setContentView(R.layout.activity_mini);
-
-        Log.i(TAG, "SHOW in Logcat");
-
-        // Prepare the worker thread for mode loading and inference
-        worker = new HandlerThread("Predictor Worker");
-        worker.start();
-        sender = new Handler(worker.getLooper()) {
-            public void handleMessage(Message msg) {
-                switch (msg.what) {
-                    case REQUEST_LOAD_MODEL:
-                        // Load model and reload test image
-                        if (!onLoadModel()) {
-                            runOnUiThread(new Runnable() {
-                                @Override
-                                public void run() {
-                                    Toast.makeText(MiniActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show();
-                                }
-                            });
-                        }
-                        break;
-                    case REQUEST_RUN_MODEL:
-                        // Run model if model is loaded
-                        final boolean isSuccessed = onRunModel();
-                        runOnUiThread(new Runnable() {
-                            @Override
-                            public void run() {
-                                if (isSuccessed){
-                                    onRunModelSuccessed();
-                                }else{
-                                    Toast.makeText(MiniActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show();
-                                }
-                            }
-                        });
-                        break;
-                }
-            }
-        };
-        sender.sendEmptyMessage(REQUEST_LOAD_MODEL); // corresponding to REQUEST_LOAD_MODEL, to call onLoadModel()
-
-        imageView = findViewById(R.id.imageView);
-        textView = findViewById(R.id.sample_text);
-        button = findViewById(R.id.button);
-        button.setOnClickListener(new View.OnClickListener() {
-            @Override
-            public void onClick(View v) {
-                sender.sendEmptyMessage(REQUEST_RUN_MODEL);
-            }
-        });
-
-
-    }
-
-    @Override
-    protected void onDestroy() {
-        onUnloadModel();
-        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.JELLY_BEAN_MR2) {
-            worker.quitSafely();
-        } else {
-            worker.quit();
-        }
-        super.onDestroy();
-    }
-
-    /**
-     * call in onCreate, model init
-     *
-     * @return
-     */
-    private boolean onLoadModel() {
-        if (predictor == null) {
-            predictor = new Predictor();
-        }
-        return predictor.init(this, assetModelDirPath, assetlabelFilePath);
-    }
-
-    /**
-     * init engine
-     * call in onCreate
-     *
-     * @return
-     */
-    private boolean onRunModel() {
-        try {
-            String assetImagePath = "images/0.jpg";
-            InputStream imageStream = getAssets().open(assetImagePath);
-            Bitmap image = BitmapFactory.decodeStream(imageStream);
-            // Input is Bitmap
-            predictor.setInputImage(image);
-            return predictor.isLoaded() && predictor.runModel();
-        } catch (IOException e) {
-            e.printStackTrace();
-            return false;
-        }
-    }
-
-    private void onRunModelSuccessed() {
-        Log.i(TAG, "onRunModelSuccessed");
-        textView.setText(predictor.outputResult);
-        imageView.setImageBitmap(predictor.outputImage);
-    }
-
-    private void onUnloadModel() {
-        if (predictor != null) {
-            predictor.releaseModel();
-        }
-    }
-}
diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java
index 1fa419e32a4cbefc27fc687b376ecc7e6a1e8a2f..622da2a3f9a1233167e777e62b687c1f246df01f 100644
--- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java
+++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java
@@ -29,22 +29,22 @@ public class OCRPredictorNative {
     public OCRPredictorNative(Config config) {
         this.config = config;
         loadLibrary();
-        nativePointer = init(config.detModelFilename, config.recModelFilename,config.clsModelFilename,
+        nativePointer = init(config.detModelFilename, config.recModelFilename, config.clsModelFilename, config.useOpencl,
                 config.cpuThreadNum, config.cpuPower);
         Log.i("OCRPredictorNative", "load success " + nativePointer);
     }
 
-    public ArrayList runImage(float[] inputData, int width, int height, int channels, Bitmap originalImage) {
-        Log.i("OCRPredictorNative", "begin to run image " + inputData.length + " " + width + " " + height);
-        float[] dims = new float[]{1, channels, height, width};
-        float[] rawResults = forward(nativePointer, inputData, dims, originalImage);
+    public ArrayList runImage(Bitmap originalImage, int max_size_len, int run_det, int run_cls, int run_rec) {
+        Log.i("OCRPredictorNative", "begin to run image ");
+        float[] rawResults = forward(nativePointer, originalImage, max_size_len, run_det, run_cls, run_rec);
         ArrayList results = postprocess(rawResults);
         return results;
     }
 
     public static class Config {
+        public int useOpencl;
         public int cpuThreadNum;
         public String cpuPower;
         public String detModelFilename;
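With preprocessing moved behind the JNI boundary, the wrapper now receives the raw Bitmap plus the detector's long-side limit and per-stage switches instead of a normalized float buffer. A hedged usage sketch; the directory, image path, and setting values are illustrative, only the field and method names come from the diff:

    String dir = "/data/local/tmp/models";               // illustrative path
    OCRPredictorNative.Config config = new OCRPredictorNative.Config();
    config.useOpencl = 0;                                // 1 selects the OpenCL backend
    config.cpuThreadNum = 4;
    config.cpuPower = "LITE_POWER_HIGH";                 // assumed power-mode string
    config.detModelFilename = dir + "/det_db.nb";
    config.recModelFilename = dir + "/rec_crnn.nb";
    config.clsModelFilename = dir + "/cls.nb";
    OCRPredictorNative predictor = new OCRPredictorNative(config);
    Bitmap bitmap = BitmapFactory.decodeFile(dir + "/test.jpg");  // illustrative input
    // Long side capped at 960 px; run detection, classification and recognition:
    ArrayList results = predictor.runImage(bitmap, 960, 1, 1, 1);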
@@ -53,16 +53,16 @@ public class OCRPredictorNative {
     }
 
-    public void destory(){
+    public void destory() {
         if (nativePointer > 0) {
             release(nativePointer);
             nativePointer = 0;
         }
     }
 
-    protected native long init(String detModelPath, String recModelPath,String clsModelPath, int threadNum, String cpuMode);
+    protected native long init(String detModelPath, String recModelPath, String clsModelPath, int useOpencl, int threadNum, String cpuMode);
 
-    protected native float[] forward(long pointer, float[] buf, float[] ddims, Bitmap originalImage);
+    protected native float[] forward(long pointer, Bitmap originalImage, int max_size_len, int run_det, int run_cls, int run_rec);
 
     protected native void release(long pointer);
 
@@ -73,9 +73,9 @@ public class OCRPredictorNative {
         while (begin < raw.length) {
             int point_num = Math.round(raw[begin]);
             int word_num = Math.round(raw[begin + 1]);
-            OcrResultModel model = parse(raw, begin + 2, point_num, word_num);
-            begin += 2 + 1 + point_num * 2 + word_num;
-            results.add(model);
+            OcrResultModel res = parse(raw, begin + 2, point_num, word_num);
+            begin += 2 + 1 + point_num * 2 + word_num + 2;
+            results.add(res);
         }
 
         return results;
@@ -83,19 +83,22 @@ public class OCRPredictorNative {
 
     private OcrResultModel parse(float[] raw, int begin, int pointNum, int wordNum) {
         int current = begin;
-        OcrResultModel model = new OcrResultModel();
-        model.setConfidence(raw[current]);
+        OcrResultModel res = new OcrResultModel();
+        res.setConfidence(raw[current]);
         current++;
         for (int i = 0; i < pointNum; i++) {
-            model.addPoints(Math.round(raw[current + i * 2]), Math.round(raw[current + i * 2 + 1]));
+            res.addPoints(Math.round(raw[current + i * 2]), Math.round(raw[current + i * 2 + 1]));
         }
         current += (pointNum * 2);
         for (int i = 0; i < wordNum; i++) {
             int index = Math.round(raw[current + i]);
-            model.addWordIndex(index);
+            res.addWordIndex(index);
         }
+        current += wordNum;
+        res.setClsIdx(raw[current]);
+        res.setClsConfidence(raw[current + 1]);
         Log.i("OCRPredictorNative", "word finished " + wordNum);
-        return model;
+        return res;
     }
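parse() and the updated stride in postprocess() (begin += 2 + 1 + point_num * 2 + word_num + 2) imply that each record in the flat float array returned by forward() is laid out as follows; this layout is inferred from the Java-side arithmetic, not stated anywhere in the diff:

    // Per-result layout of the raw float buffer, as decoded by parse():
    //   [0]                       point_num: number of polygon points
    //   [1]                       word_num: number of recognized character indices
    //   [2]                       recognition confidence
    //   [3 .. 3 + 2*point_num)    x,y pairs of the detection polygon
    //   [.. + word_num)           character indices into the label list
    //   [last 2]                  cls_idx, cls_confidence (new in this change)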
diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java
index 9494574e0db2d0c5d11ce4ad1cb400710c71a623..1bccbc7d51c905642d2b031706370f0ab9f50afc 100644
--- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java
+++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OcrResultModel.java
@@ -10,6 +10,9 @@ public class OcrResultModel {
     private List wordIndex;
     private String label;
     private float confidence;
+    private float cls_idx;
+    private String cls_label;
+    private float cls_confidence;
 
     public OcrResultModel() {
         super();
@@ -49,4 +52,28 @@ public class OcrResultModel {
     public void setConfidence(float confidence) {
         this.confidence = confidence;
     }
+
+    public float getClsIdx() {
+        return cls_idx;
+    }
+
+    public void setClsIdx(float idx) {
+        this.cls_idx = idx;
+    }
+
+    public String getClsLabel() {
+        return cls_label;
+    }
+
+    public void setClsLabel(String label) {
+        this.cls_label = label;
+    }
+
+    public float getClsConfidence() {
+        return cls_confidence;
+    }
+
+    public void setClsConfidence(float confidence) {
+        this.cls_confidence = confidence;
+    }
 }
diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java
index 8bcd79b95b322a38dcd56d6ffe3a203d3d1ea6ae..ab312161d14ba2c1cdbad278b248bd68b042ed39 100644
--- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java
+++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/Predictor.java
@@ -31,23 +31,19 @@ public class Predictor {
     protected float inferenceTime = 0;
     // Only for object detection
     protected Vector wordLabels = new Vector();
-    protected String inputColorFormat = "BGR";
-    protected long[] inputShape = new long[]{1, 3, 960};
-    protected float[] inputMean = new float[]{0.485f, 0.456f, 0.406f};
-    protected float[] inputStd = new float[]{1.0f / 0.229f, 1.0f / 0.224f, 1.0f / 0.225f};
+    protected int detLongSize = 960;
     protected float scoreThreshold = 0.1f;
     protected Bitmap inputImage = null;
     protected Bitmap outputImage = null;
     protected volatile String outputResult = "";
-    protected float preprocessTime = 0;
     protected float postprocessTime = 0;
 
     public Predictor() {
     }
 
-    public boolean init(Context appCtx, String modelPath, String labelPath) {
-        isLoaded = loadModel(appCtx, modelPath, cpuThreadNum, cpuPowerMode);
+    public boolean init(Context appCtx, String modelPath, String labelPath, int useOpencl, int cpuThreadNum, String cpuPowerMode) {
+        isLoaded = loadModel(appCtx, modelPath, useOpencl, cpuThreadNum, cpuPowerMode);
         if (!isLoaded) {
             return false;
         }
@@ -56,49 +52,18 @@ public class Predictor {
     }
 
-    public boolean init(Context appCtx, String modelPath, String labelPath, int cpuThreadNum, String cpuPowerMode,
-                        String inputColorFormat,
-                        long[] inputShape, float[] inputMean,
-                        float[] inputStd, float scoreThreshold) {
-        if (inputShape.length != 3) {
-            Log.e(TAG, "Size of input shape should be: 3");
-            return false;
-        }
-        if (inputMean.length != inputShape[1]) {
-            Log.e(TAG, "Size of input mean should be: " + Long.toString(inputShape[1]));
-            return false;
-        }
-        if (inputStd.length != inputShape[1]) {
-            Log.e(TAG, "Size of input std should be: " + Long.toString(inputShape[1]));
-            return false;
-        }
-        if (inputShape[0] != 1) {
-            Log.e(TAG, "Only one batch is supported in the image classification demo, you can use any batch size in " +
-                    "your Apps!");
-            return false;
-        }
-        if (inputShape[1] != 1 && inputShape[1] != 3) {
-            Log.e(TAG, "Only one/three channels are supported in the image classification demo, you can use any " +
-                    "channel size in your Apps!");
-            return false;
-        }
-        if (!inputColorFormat.equalsIgnoreCase("BGR")) {
-            Log.e(TAG, "Only BGR color format is supported.");
-            return false;
-        }
-        boolean isLoaded = init(appCtx, modelPath, labelPath);
+    public boolean init(Context appCtx, String modelPath, String labelPath, int useOpencl, int cpuThreadNum, String cpuPowerMode,
+                        int detLongSize, float scoreThreshold) {
+        boolean isLoaded = init(appCtx, modelPath, labelPath, useOpencl, cpuThreadNum, cpuPowerMode);
         if (!isLoaded) {
             return false;
         }
-        this.inputColorFormat = inputColorFormat;
-        this.inputShape = inputShape;
-        this.inputMean = inputMean;
-        this.inputStd = inputStd;
+        this.detLongSize = detLongSize;
         this.scoreThreshold = scoreThreshold;
         return true;
     }
 
-    protected boolean loadModel(Context appCtx, String modelPath, int cpuThreadNum, String cpuPowerMode) {
+    protected boolean loadModel(Context appCtx, String modelPath, int useOpencl, int cpuThreadNum, String cpuPowerMode) {
         // Release model if exists
         releaseModel();
 
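The second init() overload above no longer validates shape, mean and std, since normalization now happens in native code; callers pass the OpenCL flag, CPU settings, detection long size and score threshold directly. A sketch of the call MainActivity.onLoadModel() now makes; the numeric values are the defaults visible in the diff, the power-mode string and surrounding variables are assumptions:

    // Inside an Activity; modelPath and labelPath as stored in preferences:
    Predictor predictor = new Predictor();
    boolean ok = predictor.init(this, modelPath, labelPath,
            cbOpenclChecked ? 1 : 0,  // cbOpencl.isChecked() in MainActivity
            4,                        // cpuThreadNum
            "LITE_POWER_HIGH",        // cpuPowerMode (assumed value)
            960,                      // detLongSize default from the diff
            0.1f);                    // scoreThreshold default from the diff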
@@ -118,12 +83,13 @@ public class Predictor {
         }
 
         OCRPredictorNative.Config config = new OCRPredictorNative.Config();
+        config.useOpencl = useOpencl;
         config.cpuThreadNum = cpuThreadNum;
-        config.detModelFilename = realPath + File.separator + "ch_ppocr_mobile_v2.0_det_opt.nb";
-        config.recModelFilename = realPath + File.separator + "ch_ppocr_mobile_v2.0_rec_opt.nb";
-        config.clsModelFilename = realPath + File.separator + "ch_ppocr_mobile_v2.0_cls_opt.nb";
-        Log.e("Predictor", "model path" + config.detModelFilename + " ; " + config.recModelFilename + ";" + config.clsModelFilename);
         config.cpuPower = cpuPowerMode;
+        config.detModelFilename = realPath + File.separator + "det_db.nb";
+        config.recModelFilename = realPath + File.separator + "rec_crnn.nb";
+        config.clsModelFilename = realPath + File.separator + "cls.nb";
+        Log.i("Predictor", "model path" + config.detModelFilename + " ; " + config.recModelFilename + ";" + config.clsModelFilename);
         paddlePredictor = new OCRPredictorNative(config);
 
         this.cpuThreadNum = cpuThreadNum;
@@ -170,82 +136,29 @@ public class Predictor {
     }
 
-    public boolean runModel() {
+    public boolean runModel(int run_det, int run_cls, int run_rec) {
         if (inputImage == null || !isLoaded()) {
             return false;
         }
 
-        // Pre-process image, and feed input tensor with pre-processed data
-
-        Bitmap scaleImage = Utils.resizeWithStep(inputImage, Long.valueOf(inputShape[2]).intValue(), 32);
-
-        Date start = new Date();
-        int channels = (int) inputShape[1];
-        int width = scaleImage.getWidth();
-        int height = scaleImage.getHeight();
-        float[] inputData = new float[channels * width * height];
-        if (channels == 3) {
-            int[] channelIdx = null;
-            if (inputColorFormat.equalsIgnoreCase("RGB")) {
-                channelIdx = new int[]{0, 1, 2};
-            } else if (inputColorFormat.equalsIgnoreCase("BGR")) {
-                channelIdx = new int[]{2, 1, 0};
-            } else {
-                Log.i(TAG, "Unknown color format " + inputColorFormat + ", only RGB and BGR color format is " +
-                        "supported!");
-                return false;
-            }
-
-            int[] channelStride = new int[]{width * height, width * height * 2};
-            int[] pixels=new int[width*height];
-            scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight());
-            for (int i = 0; i < pixels.length; i++) {
-                int color = pixels[i];
-                float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f,
-                        (float) blue(color) / 255.0f};
-                inputData[i] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0];
-                inputData[i + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1];
-                inputData[i+ channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2];
-            }
-        } else if (channels == 1) {
-            int[] pixels=new int[width*height];
-            scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight());
-            for (int i = 0; i < pixels.length; i++) {
-                int color = pixels[i];
-                float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f;
-                inputData[i] = (gray - inputMean[0]) / inputStd[0];
-            }
-        } else {
-            Log.i(TAG, "Unsupported channel size " + Integer.toString(channels) + ", only channel 1 and 3 is " +
-                    "supported!");
-            return false;
-        }
-        float[] pixels = inputData;
-        Log.i(TAG, "pixels " + pixels[0] + " " + pixels[1] + " " + pixels[2] + " " + pixels[3]
-                + " " + pixels[pixels.length / 2] + " " + pixels[pixels.length / 2 + 1] + " " + pixels[pixels.length - 2] + " " + pixels[pixels.length - 1]);
-        Date end = new Date();
-        preprocessTime = (float) (end.getTime() - start.getTime());
-
         // Warm up
         for (int i = 0; i < warmupIterNum; i++) {
-            paddlePredictor.runImage(inputData, width, height, channels, inputImage);
+            paddlePredictor.runImage(inputImage, detLongSize, run_det, run_cls, run_rec);
         }
         warmupIterNum = 0; // do not need warm
         // Run inference
-        start = new Date();
-        ArrayList results = paddlePredictor.runImage(inputData, width, height, channels, inputImage);
-        end = new Date();
+        Date start = new Date();
+        ArrayList results = paddlePredictor.runImage(inputImage, detLongSize, run_det, run_cls, run_rec);
+        Date end = new Date();
         inferenceTime = (end.getTime() - start.getTime()) / (float) inferIterNum;
 
         results = postprocess(results);
-        Log.i(TAG, "[stat] Preprocess Time: " + preprocessTime
-                + " ; Inference Time: " + inferenceTime + " ;Box Size " + results.size());
+        Log.i(TAG, "[stat] Inference Time: " + inferenceTime + " ;Box Size " + results.size());
         drawResults(results);
 
         return true;
     }
 
-
     public boolean isLoaded() {
         return paddlePredictor != null && isLoaded;
     }
@@ -282,10 +195,6 @@ public class Predictor {
         return outputResult;
     }
 
-    public float preprocessTime() {
-        return preprocessTime;
-    }
-
    public float postprocessTime() {
         return postprocessTime;
     }
@@ -310,6 +219,7 @@ public class Predictor {
                 }
             }
             r.setLabel(word.toString());
+            r.setClsLabel(r.getClsIdx() == 1 ? "180" : "0");
         }
         return results;
     }
@@ -319,14 +229,22 @@ public class Predictor {
         for (int i = 0; i < results.size(); i++) {
             OcrResultModel result = results.get(i);
             StringBuilder sb = new StringBuilder("");
-            sb.append(result.getLabel());
-            sb.append(" ").append(result.getConfidence());
-            sb.append("; Points: ");
-            for (Point p : result.getPoints()) {
-                sb.append("(").append(p.x).append(",").append(p.y).append(") ");
+            if (result.getPoints().size() > 0) {
+                sb.append("Det: ");
+                for (Point p : result.getPoints()) {
+                    sb.append("(").append(p.x).append(",").append(p.y).append(") ");
+                }
+            }
+            if (result.getLabel().length() > 0) {
+                sb.append("\n Rec: ").append(result.getLabel());
+                sb.append(",").append(result.getConfidence());
+            }
+            if (result.getClsIdx() != -1) {
+                sb.append(" Cls: ").append(result.getClsLabel());
+                sb.append(",").append(result.getClsConfidence());
             }
             Log.i(TAG, sb.toString()); // show LOG in Logcat panel
-            outputResultSb.append(i + 1).append(": ").append(result.getLabel()).append("\n");
+            outputResultSb.append(i + 1).append(": ").append(sb.toString()).append("\n");
         }
         outputResult = outputResultSb.toString();
         outputImage = inputImage;
@@ -344,6 +262,9 @@ public class Predictor {
         for (OcrResultModel result : results) {
             Path path = new Path();
             List points = result.getPoints();
+            if (points.size() == 0) {
+                continue;
+            }
             path.moveTo(points.get(0).x, points.get(0).y);
             for (int i = points.size() - 1; i >= 0; i--) {
                 Point p = points.get(i);
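Because any combination of stages may have run, an OcrResultModel can now hold detection points, a recognized label, a classifier verdict, or any subset, and the output string above prints only the parts that are present. A consumer should use the same guards; a small sketch over the first result, with getters as defined in the diff:

    OcrResultModel result = results.get(0);       // illustrative
    if (result.getPoints().size() > 0) {
        // detection ran: polygon corner points are available
    }
    if (result.getLabel().length() > 0) {
        // recognition ran: text plus result.getConfidence()
    }
    if (result.getClsIdx() != -1) {
        // classification ran: result.getClsLabel() is "0" or "180"
    }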
diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/SettingsActivity.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/SettingsActivity.java
index c26fe6ea439da989bee83d64296519da7971eda6..477cd5d8a2ed12ec41a304fcf0bdea3198f31998 100644
--- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/SettingsActivity.java
+++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/SettingsActivity.java
@@ -20,16 +20,13 @@ public class SettingsActivity extends AppCompatPreferenceActivity implements Sha
     ListPreference etImagePath = null;
     ListPreference lpCPUThreadNum = null;
     ListPreference lpCPUPowerMode = null;
-    ListPreference lpInputColorFormat = null;
-    EditTextPreference etInputShape = null;
-    EditTextPreference etInputMean = null;
-    EditTextPreference etInputStd = null;
+    EditTextPreference etDetLongSize = null;
     EditTextPreference etScoreThreshold = null;
 
     List preInstalledModelPaths = null;
     List preInstalledLabelPaths = null;
     List preInstalledImagePaths = null;
-    List preInstalledInputShapes = null;
+    List preInstalledDetLongSizes = null;
     List preInstalledCPUThreadNums = null;
     List preInstalledCPUPowerModes = null;
     List preInstalledInputColorFormats = null;
@@ -50,7 +47,7 @@ public class SettingsActivity extends AppCompatPreferenceActivity implements Sha
         preInstalledModelPaths = new ArrayList();
         preInstalledLabelPaths = new ArrayList();
         preInstalledImagePaths = new ArrayList();
-        preInstalledInputShapes = new ArrayList();
+        preInstalledDetLongSizes = new ArrayList();
         preInstalledCPUThreadNums = new ArrayList();
         preInstalledCPUPowerModes = new ArrayList();
         preInstalledInputColorFormats = new ArrayList();
@@ -63,10 +60,7 @@ public class SettingsActivity extends AppCompatPreferenceActivity implements Sha
         preInstalledImagePaths.add(getString(R.string.IMAGE_PATH_DEFAULT));
         preInstalledCPUThreadNums.add(getString(R.string.CPU_THREAD_NUM_DEFAULT));
         preInstalledCPUPowerModes.add(getString(R.string.CPU_POWER_MODE_DEFAULT));
-        preInstalledInputColorFormats.add(getString(R.string.INPUT_COLOR_FORMAT_DEFAULT));
-        preInstalledInputShapes.add(getString(R.string.INPUT_SHAPE_DEFAULT));
-        preInstalledInputMeans.add(getString(R.string.INPUT_MEAN_DEFAULT));
-        preInstalledInputStds.add(getString(R.string.INPUT_STD_DEFAULT));
+        preInstalledDetLongSizes.add(getString(R.string.DET_LONG_SIZE_DEFAULT));
         preInstalledScoreThresholds.add(getString(R.string.SCORE_THRESHOLD_DEFAULT));
 
         // Setup UI components
@@ -89,11 +83,7 @@ public class SettingsActivity extends AppCompatPreferenceActivity implements Sha
                 (ListPreference) findPreference(getString(R.string.CPU_THREAD_NUM_KEY));
         lpCPUPowerMode =
                 (ListPreference) findPreference(getString(R.string.CPU_POWER_MODE_KEY));
-        lpInputColorFormat =
-                (ListPreference) findPreference(getString(R.string.INPUT_COLOR_FORMAT_KEY));
-        etInputShape = (EditTextPreference) findPreference(getString(R.string.INPUT_SHAPE_KEY));
-        etInputMean = (EditTextPreference) findPreference(getString(R.string.INPUT_MEAN_KEY));
-        etInputStd = (EditTextPreference) findPreference(getString(R.string.INPUT_STD_KEY));
+        etDetLongSize = (EditTextPreference) findPreference(getString(R.string.DET_LONG_SIZE_KEY));
         etScoreThreshold = (EditTextPreference) findPreference(getString(R.string.SCORE_THRESHOLD_KEY));
     }
 
@@ -112,11 +102,7 @@ public class SettingsActivity extends AppCompatPreferenceActivity implements Sha
             editor.putString(getString(R.string.IMAGE_PATH_KEY), preInstalledImagePaths.get(modelIdx));
             editor.putString(getString(R.string.CPU_THREAD_NUM_KEY), preInstalledCPUThreadNums.get(modelIdx));
             editor.putString(getString(R.string.CPU_POWER_MODE_KEY), preInstalledCPUPowerModes.get(modelIdx));
-            editor.putString(getString(R.string.INPUT_COLOR_FORMAT_KEY),
-                    preInstalledInputColorFormats.get(modelIdx));
-            editor.putString(getString(R.string.INPUT_SHAPE_KEY), preInstalledInputShapes.get(modelIdx));
-            editor.putString(getString(R.string.INPUT_MEAN_KEY), preInstalledInputMeans.get(modelIdx));
-            editor.putString(getString(R.string.INPUT_STD_KEY), preInstalledInputStds.get(modelIdx));
+            editor.putString(getString(R.string.DET_LONG_SIZE_KEY), preInstalledDetLongSizes.get(modelIdx));
             editor.putString(getString(R.string.SCORE_THRESHOLD_KEY), preInstalledScoreThresholds.get(modelIdx));
             editor.apply();
@@ -129,10 +115,7 @@ public class SettingsActivity extends AppCompatPreferenceActivity implements Sha
         etImagePath.setEnabled(enableCustomSettings);
         lpCPUThreadNum.setEnabled(enableCustomSettings);
         lpCPUPowerMode.setEnabled(enableCustomSettings);
-        lpInputColorFormat.setEnabled(enableCustomSettings);
-        etInputShape.setEnabled(enableCustomSettings);
-        etInputMean.setEnabled(enableCustomSettings);
-        etInputStd.setEnabled(enableCustomSettings);
+        etDetLongSize.setEnabled(enableCustomSettings);
         etScoreThreshold.setEnabled(enableCustomSettings);
         modelPath = sharedPreferences.getString(getString(R.string.MODEL_PATH_KEY),
                 getString(R.string.MODEL_PATH_DEFAULT));
@@ -144,14 +127,8 @@ public class SettingsActivity extends AppCompatPreferenceActivity implements Sha
                 getString(R.string.CPU_THREAD_NUM_DEFAULT));
         String cpuPowerMode = sharedPreferences.getString(getString(R.string.CPU_POWER_MODE_KEY),
                 getString(R.string.CPU_POWER_MODE_DEFAULT));
-        String inputColorFormat = sharedPreferences.getString(getString(R.string.INPUT_COLOR_FORMAT_KEY),
-                getString(R.string.INPUT_COLOR_FORMAT_DEFAULT));
-        String inputShape = sharedPreferences.getString(getString(R.string.INPUT_SHAPE_KEY),
-                getString(R.string.INPUT_SHAPE_DEFAULT));
-        String inputMean = sharedPreferences.getString(getString(R.string.INPUT_MEAN_KEY),
-                getString(R.string.INPUT_MEAN_DEFAULT));
-        String inputStd = sharedPreferences.getString(getString(R.string.INPUT_STD_KEY),
-                getString(R.string.INPUT_STD_DEFAULT));
+        String detLongSize = sharedPreferences.getString(getString(R.string.DET_LONG_SIZE_KEY),
+                getString(R.string.DET_LONG_SIZE_DEFAULT));
         String scoreThreshold = sharedPreferences.getString(getString(R.string.SCORE_THRESHOLD_KEY),
                 getString(R.string.SCORE_THRESHOLD_DEFAULT));
         etModelPath.setSummary(modelPath);
@@ -164,14 +141,8 @@ public class SettingsActivity extends AppCompatPreferenceActivity implements Sha
         lpCPUThreadNum.setSummary(cpuThreadNum);
         lpCPUPowerMode.setValue(cpuPowerMode);
         lpCPUPowerMode.setSummary(cpuPowerMode);
-        lpInputColorFormat.setValue(inputColorFormat);
-        lpInputColorFormat.setSummary(inputColorFormat);
-        etInputShape.setSummary(inputShape);
-        etInputShape.setText(inputShape);
-        etInputMean.setSummary(inputMean);
-        etInputMean.setText(inputMean);
-        etInputStd.setSummary(inputStd);
-        etInputStd.setText(inputStd);
+        etDetLongSize.setSummary(detLongSize);
+        etDetLongSize.setText(detLongSize);
         etScoreThreshold.setText(scoreThreshold);
         etScoreThreshold.setSummary(scoreThreshold);
     }
diff --git a/deploy/android_demo/app/src/main/res/layout/activity_main.xml b/deploy/android_demo/app/src/main/res/layout/activity_main.xml
index 5caf568ee8191f0383381c492bee0e48d77401c3..e90c99a68b0867d881925198d91c4bfcbbc22e8b 100644
--- a/deploy/android_demo/app/src/main/res/layout/activity_main.xml
+++ b/deploy/android_demo/app/src/main/res/layout/activity_main.xml
@@ -23,13 +23,7 @@
         android:layout_height="wrap_content"
         android:orientation="horizontal">
 
-