package com.baidu.paddle.lite.demo; import android.content.Context; import android.util.Log; import com.baidu.paddle.lite.*; import java.util.ArrayList; import java.util.Date; public class Predictor { private static final String TAG = Predictor.class.getSimpleName(); public boolean isLoaded = false; public int warmupIterNum = 0; public int inferIterNum = 1; protected Context appCtx = null; public int cpuThreadNum = 1; public String cpuPowerMode = "LITE_POWER_HIGH"; public String modelPath = ""; public String modelName = ""; protected PaddlePredictor paddlePredictor = null; protected float inferenceTime = 0; public Predictor() { } public boolean init(Context appCtx, String modelPath, int cpuThreadNum, String cpuPowerMode) { this.appCtx = appCtx; isLoaded = loadModel(modelPath, cpuThreadNum, cpuPowerMode); return isLoaded; } protected boolean loadModel(String modelPath, int cpuThreadNum, String cpuPowerMode) { // release model if exists releaseModel(); // load model if (modelPath.isEmpty()) { return false; } String realPath = modelPath; if (!modelPath.substring(0, 1).equals("/")) { // read model files from custom file_paths if the first character of mode file_paths is '/' // otherwise copy model to cache from assets realPath = appCtx.getCacheDir() + "/" + modelPath; Utils.copyDirectoryFromAssets(appCtx, modelPath, realPath); } if (realPath.isEmpty()) { return false; } MobileConfig config = new MobileConfig(); config.setModelDir(realPath); config.setThreads(cpuThreadNum); if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_HIGH")) { config.setPowerMode(PowerMode.LITE_POWER_HIGH); } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_LOW")) { config.setPowerMode(PowerMode.LITE_POWER_LOW); } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_FULL")) { config.setPowerMode(PowerMode.LITE_POWER_FULL); } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_NO_BIND")) { config.setPowerMode(PowerMode.LITE_POWER_NO_BIND); } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_RAND_HIGH")) { config.setPowerMode(PowerMode.LITE_POWER_RAND_HIGH); } else if (cpuPowerMode.equalsIgnoreCase("LITE_POWER_RAND_LOW")) { config.setPowerMode(PowerMode.LITE_POWER_RAND_LOW); } else { Log.e(TAG, "unknown cpu power mode!"); return false; } paddlePredictor = PaddlePredictor.createPaddlePredictor(config); this.cpuThreadNum = cpuThreadNum; this.cpuPowerMode = cpuPowerMode; this.modelPath = realPath; this.modelName = realPath.substring(realPath.lastIndexOf("/") + 1); return true; } public void releaseModel() { paddlePredictor = null; isLoaded = false; cpuThreadNum = 1; cpuPowerMode = "LITE_POWER_HIGH"; modelPath = ""; modelName = ""; } public Tensor getInput(int idx) { if (!isLoaded()) { return null; } return paddlePredictor.getInput(idx); } public Tensor getOutput(int idx) { if (!isLoaded()) { return null; } return paddlePredictor.getOutput(idx); } public boolean runModel() { if (!isLoaded()) { return false; } // warm up for (int i = 0; i < warmupIterNum; i++){ paddlePredictor.run(); } // inference Date start = new Date(); for (int i = 0; i < inferIterNum; i++) { paddlePredictor.run(); } Date end = new Date(); inferenceTime = (end.getTime() - start.getTime()) / (float) inferIterNum; return true; } public boolean isLoaded() { return paddlePredictor != null && isLoaded; } public String modelPath() { return modelPath; } public String modelName() { return modelName; } public int cpuThreadNum() { return cpuThreadNum; } public String cpuPowerMode() { return cpuPowerMode; } public float inferenceTime() { return inferenceTime; } }