package com.baidu.paddle.lite.demo.ocr; import android.content.Context; import android.graphics.Bitmap; import android.graphics.Canvas; import android.graphics.Color; import android.graphics.Paint; import android.graphics.Path; import android.graphics.Point; import android.util.Log; import java.io.File; import java.io.InputStream; import java.util.ArrayList; import java.util.Date; import java.util.List; import java.util.Vector; import static android.graphics.Color.*; public class Predictor { private static final String TAG = Predictor.class.getSimpleName(); public boolean isLoaded = false; public int warmupIterNum = 1; public int inferIterNum = 1; public int cpuThreadNum = 4; public String cpuPowerMode = "LITE_POWER_HIGH"; public String modelPath = ""; public String modelName = ""; protected OCRPredictorNative paddlePredictor = null; protected float inferenceTime = 0; // Only for object detection protected Vector wordLabels = new Vector(); protected String inputColorFormat = "BGR"; protected long[] inputShape = new long[]{1, 3, 960}; protected float[] inputMean = new float[]{0.485f, 0.456f, 0.406f}; protected float[] inputStd = new float[]{1.0f / 0.229f, 1.0f / 0.224f, 1.0f / 0.225f}; protected float scoreThreshold = 0.1f; protected Bitmap inputImage = null; protected Bitmap outputImage = null; protected volatile String outputResult = ""; protected float preprocessTime = 0; protected float postprocessTime = 0; public Predictor() { } public boolean init(Context appCtx, String modelPath, String labelPath) { isLoaded = loadModel(appCtx, modelPath, cpuThreadNum, cpuPowerMode); if (!isLoaded) { return false; } isLoaded = loadLabel(appCtx, labelPath); return isLoaded; } public boolean init(Context appCtx, String modelPath, String labelPath, int cpuThreadNum, String cpuPowerMode, String inputColorFormat, long[] inputShape, float[] inputMean, float[] inputStd, float scoreThreshold) { if (inputShape.length != 3) { Log.e(TAG, "Size of input shape should be: 3"); return false; } if (inputMean.length != inputShape[1]) { Log.e(TAG, "Size of input mean should be: " + Long.toString(inputShape[1])); return false; } if (inputStd.length != inputShape[1]) { Log.e(TAG, "Size of input std should be: " + Long.toString(inputShape[1])); return false; } if (inputShape[0] != 1) { Log.e(TAG, "Only one batch is supported in the image classification demo, you can use any batch size in " + "your Apps!"); return false; } if (inputShape[1] != 1 && inputShape[1] != 3) { Log.e(TAG, "Only one/three channels are supported in the image classification demo, you can use any " + "channel size in your Apps!"); return false; } if (!inputColorFormat.equalsIgnoreCase("BGR")) { Log.e(TAG, "Only BGR color format is supported."); return false; } boolean isLoaded = init(appCtx, modelPath, labelPath); if (!isLoaded) { return false; } this.inputColorFormat = inputColorFormat; this.inputShape = inputShape; this.inputMean = inputMean; this.inputStd = inputStd; this.scoreThreshold = scoreThreshold; return true; } protected boolean loadModel(Context appCtx, String modelPath, int cpuThreadNum, String cpuPowerMode) { // Release model if exists releaseModel(); // Load model if (modelPath.isEmpty()) { return false; } String realPath = modelPath; if (!modelPath.substring(0, 1).equals("/")) { // Read model files from custom path if the first character of mode path is '/' // otherwise copy model to cache from assets realPath = appCtx.getCacheDir() + "/" + modelPath; Utils.copyDirectoryFromAssets(appCtx, modelPath, realPath); } if (realPath.isEmpty()) { return false; } OCRPredictorNative.Config config = new OCRPredictorNative.Config(); config.cpuThreadNum = cpuThreadNum; config.detModelFilename = realPath + File.separator + "ch_ppocr_mobile_v2.0_det_opt.nb"; config.recModelFilename = realPath + File.separator + "ch_ppocr_mobile_v2.0_rec_opt.nb"; config.clsModelFilename = realPath + File.separator + "ch_ppocr_mobile_v2.0_cls_opt.nb"; Log.e("Predictor", "model path" + config.detModelFilename + " ; " + config.recModelFilename + ";" + config.clsModelFilename); config.cpuPower = cpuPowerMode; paddlePredictor = new OCRPredictorNative(config); this.cpuThreadNum = cpuThreadNum; this.cpuPowerMode = cpuPowerMode; this.modelPath = realPath; this.modelName = realPath.substring(realPath.lastIndexOf("/") + 1); return true; } public void releaseModel() { if (paddlePredictor != null) { paddlePredictor.destory(); paddlePredictor = null; } isLoaded = false; cpuThreadNum = 1; cpuPowerMode = "LITE_POWER_HIGH"; modelPath = ""; modelName = ""; } protected boolean loadLabel(Context appCtx, String labelPath) { wordLabels.clear(); // Load word labels from file try { InputStream assetsInputStream = appCtx.getAssets().open(labelPath); int available = assetsInputStream.available(); byte[] lines = new byte[available]; assetsInputStream.read(lines); assetsInputStream.close(); String words = new String(lines); String[] contents = words.split("\n"); for (String content : contents) { wordLabels.add(content); } Log.i(TAG, "Word label size: " + wordLabels.size()); } catch (Exception e) { Log.e(TAG, e.getMessage()); return false; } return true; } public boolean runModel() { if (inputImage == null || !isLoaded()) { return false; } // Pre-process image, and feed input tensor with pre-processed data Bitmap scaleImage = Utils.resizeWithStep(inputImage, Long.valueOf(inputShape[2]).intValue(), 32); Date start = new Date(); int channels = (int) inputShape[1]; int width = scaleImage.getWidth(); int height = scaleImage.getHeight(); float[] inputData = new float[channels * width * height]; if (channels == 3) { int[] channelIdx = null; if (inputColorFormat.equalsIgnoreCase("RGB")) { channelIdx = new int[]{0, 1, 2}; } else if (inputColorFormat.equalsIgnoreCase("BGR")) { channelIdx = new int[]{2, 1, 0}; } else { Log.i(TAG, "Unknown color format " + inputColorFormat + ", only RGB and BGR color format is " + "supported!"); return false; } int[] channelStride = new int[]{width * height, width * height * 2}; int p = scaleImage.getPixel(scaleImage.getWidth() - 1, scaleImage.getHeight() - 1); for (int y = 0; y < height; y++) { for (int x = 0; x < width; x++) { int color = scaleImage.getPixel(x, y); float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f, (float) blue(color) / 255.0f}; inputData[y * width + x] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0]; inputData[y * width + x + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1]; inputData[y * width + x + channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2]; } } } else if (channels == 1) { for (int y = 0; y < height; y++) { for (int x = 0; x < width; x++) { int color = inputImage.getPixel(x, y); float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f; inputData[y * width + x] = (gray - inputMean[0]) / inputStd[0]; } } } else { Log.i(TAG, "Unsupported channel size " + Integer.toString(channels) + ", only channel 1 and 3 is " + "supported!"); return false; } float[] pixels = inputData; Log.i(TAG, "pixels " + pixels[0] + " " + pixels[1] + " " + pixels[2] + " " + pixels[3] + " " + pixels[pixels.length / 2] + " " + pixels[pixels.length / 2 + 1] + " " + pixels[pixels.length - 2] + " " + pixels[pixels.length - 1]); Date end = new Date(); preprocessTime = (float) (end.getTime() - start.getTime()); // Warm up for (int i = 0; i < warmupIterNum; i++) { paddlePredictor.runImage(inputData, width, height, channels, inputImage); } warmupIterNum = 0; // do not need warm // Run inference start = new Date(); ArrayList results = paddlePredictor.runImage(inputData, width, height, channels, inputImage); end = new Date(); inferenceTime = (end.getTime() - start.getTime()) / (float) inferIterNum; results = postprocess(results); Log.i(TAG, "[stat] Preprocess Time: " + preprocessTime + " ; Inference Time: " + inferenceTime + " ;Box Size " + results.size()); drawResults(results); return true; } public boolean isLoaded() { return paddlePredictor != null && isLoaded; } public String modelPath() { return modelPath; } public String modelName() { return modelName; } public int cpuThreadNum() { return cpuThreadNum; } public String cpuPowerMode() { return cpuPowerMode; } public float inferenceTime() { return inferenceTime; } public Bitmap inputImage() { return inputImage; } public Bitmap outputImage() { return outputImage; } public String outputResult() { return outputResult; } public float preprocessTime() { return preprocessTime; } public float postprocessTime() { return postprocessTime; } public void setInputImage(Bitmap image) { if (image == null) { return; } this.inputImage = image.copy(Bitmap.Config.ARGB_8888, true); } private ArrayList postprocess(ArrayList results) { for (OcrResultModel r : results) { StringBuffer word = new StringBuffer(); for (int index : r.getWordIndex()) { if (index >= 0 && index < wordLabels.size()) { word.append(wordLabels.get(index)); } else { Log.e(TAG, "Word index is not in label list:" + index); word.append("×"); } } r.setLabel(word.toString()); } return results; } private void drawResults(ArrayList results) { StringBuffer outputResultSb = new StringBuffer(""); for (int i = 0; i < results.size(); i++) { OcrResultModel result = results.get(i); StringBuilder sb = new StringBuilder(""); sb.append(result.getLabel()); sb.append(" ").append(result.getConfidence()); sb.append("; Points: "); for (Point p : result.getPoints()) { sb.append("(").append(p.x).append(",").append(p.y).append(") "); } Log.i(TAG, sb.toString()); // show LOG in Logcat panel outputResultSb.append(i + 1).append(": ").append(result.getLabel()).append("\n"); } outputResult = outputResultSb.toString(); outputImage = inputImage; Canvas canvas = new Canvas(outputImage); Paint paintFillAlpha = new Paint(); paintFillAlpha.setStyle(Paint.Style.FILL); paintFillAlpha.setColor(Color.parseColor("#3B85F5")); paintFillAlpha.setAlpha(50); Paint paint = new Paint(); paint.setColor(Color.parseColor("#3B85F5")); paint.setStrokeWidth(5); paint.setStyle(Paint.Style.STROKE); for (OcrResultModel result : results) { Path path = new Path(); List points = result.getPoints(); path.moveTo(points.get(0).x, points.get(0).y); for (int i = points.size() - 1; i >= 0; i--) { Point p = points.get(i); path.lineTo(p.x, p.y); } canvas.drawPath(path, paint); canvas.drawPath(path, paintFillAlpha); } } }