提交 16a2ab58 作者: LKKlein

fix default download path; auto-download dict from http

上级 62d302ff
...@@ -261,7 +261,7 @@ func (ocr *OCRSystem) StartServer(port string) ...@@ -261,7 +261,7 @@ func (ocr *OCRSystem) StartServer(port string)
当前给定的配置文件`config/conf.yaml`中,包含了默认的OCR预测配置参数,可根据个人需要更改相关参数。 当前给定的配置文件`config/conf.yaml`中,包含了默认的OCR预测配置参数,可根据个人需要更改相关参数。
比如,将`use_gpu`改为`false`,使用CPU执行预测;将`det_model_dir`, `rec_model_dir`, `cls_model_dir`都更改为自己的本地模型路径,也或者是更改字典`rec_char_dict_path`的路径。配置参数包含了预测引擎、检测模型、检测阈值、方向分类模型、识别模型及阈值的相关参数,具体参数的意义可参见[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/whl.md#%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E)。 比如,将`use_gpu`改为`false`,使用CPU执行预测;将`det_model_dir`, `rec_model_dir`, `cls_model_dir`都更改为自己的本地模型路径,也或者是更改字典`rec_char_dict_path`的路径,这四个路径如果配置http链接,会自动下载到本地目录。另外,配置参数包含了预测引擎、检测模型、检测阈值、方向分类模型、识别模型及阈值的相关参数,具体参数的意义可参见[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/whl.md#%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E)。
### 3.2 编译预测demo ### 3.2 编译预测demo
......
...@@ -30,7 +30,7 @@ rec_image_shape: [3, 32, 320] ...@@ -30,7 +30,7 @@ rec_image_shape: [3, 32, 320]
rec_char_type: "ch" rec_char_type: "ch"
rec_batch_num: 30 rec_batch_num: 30
max_text_length: 25 max_text_length: 25
rec_char_dict_path: "config/ppocr_keys_v1.txt" rec_char_dict_path: "https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/develop/ppocr/utils/ppocr_keys_v1.txt"
use_space_char: true use_space_char: true
# params for text classifier # params for text classifier
......
...@@ -28,7 +28,7 @@ var ( ...@@ -28,7 +28,7 @@ var (
"rec_char_type": "ch", "rec_char_type": "ch",
"rec_batch_num": 30, "rec_batch_num": 30,
"max_text_length": 25, "max_text_length": 25,
"rec_char_dict_path": "config/ppocr_keys_v1.txt", "rec_char_dict_path": "https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/develop/ppocr/utils/ppocr_keys_v1.txt",
"use_space_char": true, "use_space_char": true,
"use_angle_cls": false, "use_angle_cls": false,
......
...@@ -2,6 +2,7 @@ package ocr ...@@ -2,6 +2,7 @@ package ocr
import ( import (
"log" "log"
"os"
"time" "time"
"github.com/LKKlein/gocv" "github.com/LKKlein/gocv"
...@@ -34,7 +35,8 @@ func NewTextClassifier(modelDir string, args map[string]interface{}) *TextClassi ...@@ -34,7 +35,8 @@ func NewTextClassifier(modelDir string, args map[string]interface{}) *TextClassi
shape: shapes, shape: shapes,
} }
if checkModelExists(modelDir) { if checkModelExists(modelDir) {
modelDir, _ = downloadModel("./inference/cls", modelDir) home, _ := os.UserHomeDir()
modelDir, _ = downloadModel(home+"/.paddleocr/cls", modelDir)
} else { } else {
log.Panicf("cls model path: %v not exist! Please check!", modelDir) log.Panicf("cls model path: %v not exist! Please check!", modelDir)
} }
......
...@@ -2,6 +2,7 @@ package ocr ...@@ -2,6 +2,7 @@ package ocr
import ( import (
"log" "log"
"os"
"time" "time"
"github.com/LKKlein/gocv" "github.com/LKKlein/gocv"
...@@ -25,7 +26,8 @@ func NewDBDetector(modelDir string, args map[string]interface{}) *DBDetector { ...@@ -25,7 +26,8 @@ func NewDBDetector(modelDir string, args map[string]interface{}) *DBDetector {
postProcess: NewDBPostProcess(thresh, boxThresh, unClipRatio), postProcess: NewDBPostProcess(thresh, boxThresh, unClipRatio),
} }
if checkModelExists(modelDir) { if checkModelExists(modelDir) {
modelDir, _ = downloadModel("./inference/det", modelDir) home, _ := os.UserHomeDir()
modelDir, _ = downloadModel(home+"/.paddleocr/det", modelDir)
} else { } else {
log.Panicf("det model path: %v not exist! Please check!", modelDir) log.Panicf("det model path: %v not exist! Please check!", modelDir)
} }
......
...@@ -2,6 +2,7 @@ package ocr ...@@ -2,6 +2,7 @@ package ocr
import ( import (
"log" "log"
"os"
"time" "time"
"github.com/LKKlein/gocv" "github.com/LKKlein/gocv"
...@@ -37,7 +38,8 @@ func NewTextRecognizer(modelDir string, args map[string]interface{}) *TextRecogn ...@@ -37,7 +38,8 @@ func NewTextRecognizer(modelDir string, args map[string]interface{}) *TextRecogn
labels: labels, labels: labels,
} }
if checkModelExists(modelDir) { if checkModelExists(modelDir) {
modelDir, _ = downloadModel("./inference/rec/ch", modelDir) home, _ := os.UserHomeDir()
modelDir, _ = downloadModel(home+"/.paddleocr/rec/ch", modelDir)
} else { } else {
log.Panicf("rec model path: %v not exist! Please check!", modelDir) log.Panicf("rec model path: %v not exist! Please check!", modelDir)
} }
......
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"regexp"
"strings" "strings"
"github.com/LKKlein/gocv" "github.com/LKKlein/gocv"
...@@ -154,8 +153,10 @@ func isPathExist(path string) bool { ...@@ -154,8 +153,10 @@ func isPathExist(path string) bool {
func downloadModel(modelDir, modelPath string) (string, error) { func downloadModel(modelDir, modelPath string) (string, error) {
if modelPath != "" && (strings.HasPrefix(modelPath, "http://") || if modelPath != "" && (strings.HasPrefix(modelPath, "http://") ||
strings.HasPrefix(modelPath, "ftp://") || strings.HasPrefix(modelPath, "https://")) { strings.HasPrefix(modelPath, "ftp://") || strings.HasPrefix(modelPath, "https://")) {
reg := regexp.MustCompile("^(http|https|ftp)://[^/]+/(.+)") if checkModelExists(modelDir) {
suffix := reg.FindStringSubmatch(modelPath)[2] return modelDir, nil
}
_, suffix := path.Split(modelPath)
outPath := filepath.Join(modelDir, suffix) outPath := filepath.Join(modelDir, suffix)
outDir := filepath.Dir(outPath) outDir := filepath.Dir(outPath)
if !isPathExist(outDir) { if !isPathExist(outDir) {
...@@ -168,16 +169,13 @@ func downloadModel(modelDir, modelPath string) (string, error) { ...@@ -168,16 +169,13 @@ func downloadModel(modelDir, modelPath string) (string, error) {
return "", err return "", err
} }
} }
if strings.HasSuffix(outPath, ".tar") {
_, f := path.Split(suffix) if strings.HasSuffix(outPath, ".tar") && !checkModelExists(modelDir) {
nextDir := strings.TrimSuffix(f, ".tar") unTar(modelDir, outPath)
finalPath := modelDir + "/" + nextDir os.Remove(outPath)
if !checkModelExists(finalPath) { return modelDir, nil
unTar(modelDir, outPath)
}
return finalPath, nil
} }
return outPath, nil return modelDir, nil
} }
return modelPath, nil return modelPath, nil
} }
...@@ -202,15 +200,16 @@ func unTar(dst, src string) (err error) { ...@@ -202,15 +200,16 @@ func unTar(dst, src string) (err error) {
continue continue
} }
dstFileDir := filepath.Join(dst, hdr.Name) var dstFileDir string
if strings.Contains(hdr.Name, "model") {
dstFileDir = filepath.Join(dst, "model")
} else if strings.Contains(hdr.Name, "params") {
dstFileDir = filepath.Join(dst, "params")
}
switch hdr.Typeflag { switch hdr.Typeflag {
case tar.TypeDir: case tar.TypeDir:
if b := isPathExist(dstFileDir); !b { continue
if err := os.MkdirAll(dstFileDir, 0775); err != nil {
return err
}
}
case tar.TypeReg: case tar.TypeReg:
file, err := os.OpenFile(dstFileDir, os.O_CREATE|os.O_RDWR, os.FileMode(hdr.Mode)) file, err := os.OpenFile(dstFileDir, os.O_CREATE|os.O_RDWR, os.FileMode(hdr.Mode))
if err != nil { if err != nil {
...@@ -227,10 +226,24 @@ func unTar(dst, src string) (err error) { ...@@ -227,10 +226,24 @@ func unTar(dst, src string) (err error) {
return nil return nil
} }
func readLines2StringSlice(path string) []string { func readLines2StringSlice(filepath string) []string {
content, err := ioutil.ReadFile(path) if strings.HasPrefix(filepath, "http://") || strings.HasPrefix(filepath, "https://") {
home, _ := os.UserHomeDir()
dir := home + "/.paddleocr/rec/"
_, suffix := path.Split(filepath)
f := dir + suffix
if !isPathExist(f) {
err := downloadFile(f, filepath)
if err != nil {
log.Println("download ppocr key file error!")
return nil
}
}
filepath = f
}
content, err := ioutil.ReadFile(filepath)
if err != nil { if err != nil {
log.Println("read file error!") log.Println("read ppocr key file error!")
return nil return nil
} }
lines := strings.Split(string(content), "\n") lines := strings.Split(string(content), "\n")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册