ocr_det.go 1.5 KB
Newer Older
L
LKKlein 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
package ocr

import (
	"log"
	"time"

	"github.com/LKKlein/gocv"
)

// DBDetector runs a text-detection inference model together with the
// pre-processing that produces its input and the post-processing that turns
// its output into boxes. The "db" naming of its settings presumably refers to
// the DB (Differentiable Binarization) detection approach — confirm.
type DBDetector struct {
	// Embedded inference model; supplies the input tensor, output tensors,
	// predictor and LoadModel used by NewDBDetector and Run.
	*PaddleModel
	// Converts an image into flat model input data plus the resized height
	// and width that Run uses to shape the input tensor.
	preProcess  DetPreProcess
	// Converts the model's first output tensor into boxes, given the original
	// image size and the resize ratios.
	postProcess DetPostProcess
}

// NewDBDetector builds a DBDetector and loads its inference model.
//
// modelDir must be accepted by checkModelExists. It is then passed to
// downloadModel with "./inference/det" as the target directory, and the path
// that call returns is the one loaded.
// NOTE(review): presumably modelDir may be a local model directory or a
// remote URL that downloadModel fetches into ./inference/det — confirm
// against those helpers.
//
// Recognised keys in args (the literal in parentheses is the value passed to
// getInt/getFloat64, presumably used when the key is absent):
//   - "det_max_side_len" (960): forwarded to NewDBProcess.
//   - "det_db_thresh" (0.3): forwarded to NewDBPostProcess.
//   - "det_db_box_thresh" (0.5): forwarded to NewDBPostProcess.
//   - "det_db_unclip_ratio" (2.0): forwarded to NewDBPostProcess.
//
// args is also passed unchanged to NewPaddleModel.
//
// It panics (via log.Panicf) if checkModelExists rejects modelDir or if
// downloadModel reports an error.
func NewDBDetector(modelDir string, args map[string]interface{}) *DBDetector {
	// Resolve the model location first so an unusable path fails fast,
	// before any predictor configuration is built.
	if !checkModelExists(modelDir) {
		log.Panicf("det model path: %v not exist! Please check!", modelDir)
	}
	resolvedDir, err := downloadModel("./inference/det", modelDir)
	if err != nil {
		// Previously this error was discarded and whatever path came back was
		// handed straight to LoadModel.
		log.Panicf("det model download from %v failed: %v", modelDir, err)
	}

	maxSideLen := getInt(args, "det_max_side_len", 960)
	thresh := getFloat64(args, "det_db_thresh", 0.3)
	boxThresh := getFloat64(args, "det_db_box_thresh", 0.5)
	unClipRatio := getFloat64(args, "det_db_unclip_ratio", 2.0)

	detector := &DBDetector{
		PaddleModel: NewPaddleModel(args),
		preProcess:  NewDBProcess(make([]int, 0), maxSideLen),
		postProcess: NewDBPostProcess(thresh, boxThresh, unClipRatio),
	}
	detector.LoadModel(resolvedDir)
	return detector
}

// Run detects text regions in img and returns the boxes produced by the
// post-processing step.
//
// The image is pre-processed (resized) and fed to the predictor. The first
// output tensor is then handed to postProcess along with the original size and
// the resize ratios (resized / original). Those ratios are presumably used to
// map boxes back to original-image coordinates — confirm against
// DetPostProcess.
//
// An empty image (zero rows or columns) yields nil, meaning no boxes, without
// running inference. Otherwise the resize ratios would divide by zero.
//
// The box count and the time spent in inference plus post-processing are
// logged.
func (det *DBDetector) Run(img gocv.Mat) [][][]int {
	oriH, oriW := img.Rows(), img.Cols()
	if oriH <= 0 || oriW <= 0 {
		return nil
	}

	data, resizeH, resizeW := det.preProcess.Run(img)

	// Timing deliberately excludes pre-processing.
	st := time.Now()
	det.input.SetValue(data)
	// Give the flat data a batch-of-one shape [1, 3, resizeH, resizeW]
	// (presumably NCHW).
	det.input.Reshape([]int32{1, 3, int32(resizeH), int32(resizeW)})

	det.predictor.SetZeroCopyInput(det.input)
	det.predictor.ZeroCopyRun()
	det.predictor.GetZeroCopyOutput(det.outputs[0])

	ratioH := float64(resizeH) / float64(oriH)
	ratioW := float64(resizeW) / float64(oriW)
	boxes := det.postProcess.Run(det.outputs[0], oriH, oriW, ratioH, ratioW)
	log.Println("det_box num: ", len(boxes), ", time elapse: ", time.Since(st))
	return boxes
}