提交 16a2ab58 编写于 作者: L LKKlein

fix default download path; auto-download dict from http

上级 62d302ff
......@@ -261,7 +261,7 @@ func (ocr *OCRSystem) StartServer(port string)
当前给定的配置文件`config/conf.yaml`中,包含了默认的OCR预测配置参数,可根据个人需要更改相关参数。
比如,将`use_gpu`改为`false`,使用CPU执行预测;将`det_model_dir`, `rec_model_dir`, `cls_model_dir`都更改为自己的本地模型路径,也或者是更改字典`rec_char_dict_path`的路径配置参数包含了预测引擎、检测模型、检测阈值、方向分类模型、识别模型及阈值的相关参数,具体参数的意义可参见[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/whl.md#%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E)
比如,将`use_gpu`改为`false`,使用CPU执行预测;将`det_model_dir`, `rec_model_dir`, `cls_model_dir`都更改为自己的本地模型路径,也或者是更改字典`rec_char_dict_path`的路径,这四个路径如果配置http链接,会自动下载到本地目录。另外,配置参数包含了预测引擎、检测模型、检测阈值、方向分类模型、识别模型及阈值的相关参数,具体参数的意义可参见[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/whl.md#%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E)
### 3.2 编译预测demo
......
......@@ -30,7 +30,7 @@ rec_image_shape: [3, 32, 320]
rec_char_type: "ch"
rec_batch_num: 30
max_text_length: 25
rec_char_dict_path: "config/ppocr_keys_v1.txt"
rec_char_dict_path: "https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/develop/ppocr/utils/ppocr_keys_v1.txt"
use_space_char: true
# params for text classifier
......
......@@ -28,7 +28,7 @@ var (
"rec_char_type": "ch",
"rec_batch_num": 30,
"max_text_length": 25,
"rec_char_dict_path": "config/ppocr_keys_v1.txt",
"rec_char_dict_path": "https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/develop/ppocr/utils/ppocr_keys_v1.txt",
"use_space_char": true,
"use_angle_cls": false,
......
......@@ -2,6 +2,7 @@ package ocr
import (
"log"
"os"
"time"
"github.com/LKKlein/gocv"
......@@ -34,7 +35,8 @@ func NewTextClassifier(modelDir string, args map[string]interface{}) *TextClassi
shape: shapes,
}
if checkModelExists(modelDir) {
modelDir, _ = downloadModel("./inference/cls", modelDir)
home, _ := os.UserHomeDir()
modelDir, _ = downloadModel(home+"/.paddleocr/cls", modelDir)
} else {
log.Panicf("cls model path: %v not exist! Please check!", modelDir)
}
......
......@@ -2,6 +2,7 @@ package ocr
import (
"log"
"os"
"time"
"github.com/LKKlein/gocv"
......@@ -25,7 +26,8 @@ func NewDBDetector(modelDir string, args map[string]interface{}) *DBDetector {
postProcess: NewDBPostProcess(thresh, boxThresh, unClipRatio),
}
if checkModelExists(modelDir) {
modelDir, _ = downloadModel("./inference/det", modelDir)
home, _ := os.UserHomeDir()
modelDir, _ = downloadModel(home+"/.paddleocr/det", modelDir)
} else {
log.Panicf("det model path: %v not exist! Please check!", modelDir)
}
......
......@@ -2,6 +2,7 @@ package ocr
import (
"log"
"os"
"time"
"github.com/LKKlein/gocv"
......@@ -37,7 +38,8 @@ func NewTextRecognizer(modelDir string, args map[string]interface{}) *TextRecogn
labels: labels,
}
if checkModelExists(modelDir) {
modelDir, _ = downloadModel("./inference/rec/ch", modelDir)
home, _ := os.UserHomeDir()
modelDir, _ = downloadModel(home+"/.paddleocr/rec/ch", modelDir)
} else {
log.Panicf("rec model path: %v not exist! Please check!", modelDir)
}
......
......@@ -9,7 +9,6 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"strings"
"github.com/LKKlein/gocv"
......@@ -154,8 +153,10 @@ func isPathExist(path string) bool {
func downloadModel(modelDir, modelPath string) (string, error) {
if modelPath != "" && (strings.HasPrefix(modelPath, "http://") ||
strings.HasPrefix(modelPath, "ftp://") || strings.HasPrefix(modelPath, "https://")) {
reg := regexp.MustCompile("^(http|https|ftp)://[^/]+/(.+)")
suffix := reg.FindStringSubmatch(modelPath)[2]
if checkModelExists(modelDir) {
return modelDir, nil
}
_, suffix := path.Split(modelPath)
outPath := filepath.Join(modelDir, suffix)
outDir := filepath.Dir(outPath)
if !isPathExist(outDir) {
......@@ -168,16 +169,13 @@ func downloadModel(modelDir, modelPath string) (string, error) {
return "", err
}
}
if strings.HasSuffix(outPath, ".tar") {
_, f := path.Split(suffix)
nextDir := strings.TrimSuffix(f, ".tar")
finalPath := modelDir + "/" + nextDir
if !checkModelExists(finalPath) {
unTar(modelDir, outPath)
}
return finalPath, nil
if strings.HasSuffix(outPath, ".tar") && !checkModelExists(modelDir) {
unTar(modelDir, outPath)
os.Remove(outPath)
return modelDir, nil
}
return outPath, nil
return modelDir, nil
}
return modelPath, nil
}
......@@ -202,15 +200,16 @@ func unTar(dst, src string) (err error) {
continue
}
dstFileDir := filepath.Join(dst, hdr.Name)
var dstFileDir string
if strings.Contains(hdr.Name, "model") {
dstFileDir = filepath.Join(dst, "model")
} else if strings.Contains(hdr.Name, "params") {
dstFileDir = filepath.Join(dst, "params")
}
switch hdr.Typeflag {
case tar.TypeDir:
if b := isPathExist(dstFileDir); !b {
if err := os.MkdirAll(dstFileDir, 0775); err != nil {
return err
}
}
continue
case tar.TypeReg:
file, err := os.OpenFile(dstFileDir, os.O_CREATE|os.O_RDWR, os.FileMode(hdr.Mode))
if err != nil {
......@@ -227,10 +226,24 @@ func unTar(dst, src string) (err error) {
return nil
}
func readLines2StringSlice(path string) []string {
content, err := ioutil.ReadFile(path)
func readLines2StringSlice(filepath string) []string {
if strings.HasPrefix(filepath, "http://") || strings.HasPrefix(filepath, "https://") {
home, _ := os.UserHomeDir()
dir := home + "/.paddleocr/rec/"
_, suffix := path.Split(filepath)
f := dir + suffix
if !isPathExist(f) {
err := downloadFile(f, filepath)
if err != nil {
log.Println("download ppocr key file error!")
return nil
}
}
filepath = f
}
content, err := ioutil.ReadFile(filepath)
if err != nil {
log.Println("read file error!")
log.Println("read ppocr key file error!")
return nil
}
lines := strings.Split(string(content), "\n")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册