From df98794b3e0ca765814809056fbf246e504ef6c5 Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Thu, 12 Jan 2017 11:27:40 -0800 Subject: [PATCH] Android: Use the Classifier interface in the activities. Change the member variable to be of type Classifier interface instead of the implementation (TensorFlowImageClassifier) in the activity class. My intention is to create new Classifier implementations that use the TensorFlow Java API (org.tensorflow.Graph, Session etc.) instead of the Android contrib API (org.tensorflow.contrib.android.TensorFlowInferenceInterface). This re-organization will make the switch between Classifier implementations easier during testing. While at it, some minor cleanups: - Get rid of unnecessary "throws IOException" - Use a factory function instead of an initializer function. Change: 144348772 --- .../tensorflow/demo/ClassifierActivity.java | 44 ++++++------- .../org/tensorflow/demo/DetectorActivity.java | 35 +++-------- .../demo/TensorFlowImageClassifier.java | 61 +++++++++++-------- .../demo/TensorFlowMultiBoxDetector.java | 48 ++++++++------- .../demo/TensorFlowYoloDetector.java | 37 +++++------ 5 files changed, 109 insertions(+), 116 deletions(-) diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java index d3fc67a0980..263751f03fd 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java @@ -63,7 +63,7 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab private static final boolean MAINTAIN_ASPECT = true; - private TensorFlowImageClassifier classifier; + private Classifier classifier; private Integer sensorOrientation; @@ -78,7 +78,6 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab private boolean computing = false; - private Matrix frameToCropTransform; private Matrix cropToFrameTransform; @@ -102,17 +101,15 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab @Override public void onPreviewSizeChosen(final Size size, final int rotation) { - final float textSizePx = TypedValue.applyDimension( - TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, - getResources().getDisplayMetrics()); + final float textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); borderedText = new BorderedText(textSizePx); borderedText.setTypeface(Typeface.MONOSPACE); - classifier = new TensorFlowImageClassifier(); - try { - final int initStatus = - classifier.initializeTensorFlow( + classifier = + TensorFlowImageClassifier.create( getAssets(), MODEL_FILE, LABEL_FILE, @@ -122,10 +119,6 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab IMAGE_STD, INPUT_NAME, OUTPUT_NAME); - if (initStatus != 0) { - LOGGER.e("TF init status != 0: %d", initStatus); - throw new RuntimeException(); - } } catch (final Exception e) { throw new RuntimeException("Error initializing TensorFlow!", e); } @@ -137,8 +130,7 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab final Display display = getWindowManager().getDefaultDisplay(); final int screenOrientation = display.getRotation(); - LOGGER.i("Sensor orientation: %d, Screen orientation: %d", - rotation, screenOrientation); + LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation); sensorOrientation = rotation + screenOrientation; @@ -147,22 +139,24 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888); - frameToCropTransform = ImageUtils.getTransformationMatrix( - previewWidth, previewHeight, - INPUT_SIZE, INPUT_SIZE, - sensorOrientation, MAINTAIN_ASPECT); + frameToCropTransform = + ImageUtils.getTransformationMatrix( + previewWidth, previewHeight, + INPUT_SIZE, INPUT_SIZE, + sensorOrientation, MAINTAIN_ASPECT); cropToFrameTransform = new Matrix(); frameToCropTransform.invert(cropToFrameTransform); yuvBytes = new byte[3][]; - addCallback(new DrawCallback() { - @Override - public void drawCallback(final Canvas canvas) { - renderDebug(canvas); - } - }); + addCallback( + new DrawCallback() { + @Override + public void drawCallback(final Canvas canvas) { + renderDebug(canvas); + } + }); } @Override diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java index c8aeb8ae25c..9ab5a7108ab 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java @@ -124,30 +124,19 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable tracker = new MultiBoxTracker(getResources().getDisplayMetrics()); - if (USE_YOLO) { - final TensorFlowYoloDetector yoloDetector = new TensorFlowYoloDetector(); - try { - final int initStatus = - yoloDetector.initializeTensorFlow( + try { + if (USE_YOLO) { + detector = + TensorFlowYoloDetector.create( getAssets(), YOLO_MODEL_FILE, YOLO_INPUT_SIZE, YOLO_INPUT_NAME, YOLO_OUTPUT_NAMES, YOLO_BLOCK_SIZE); - if (initStatus != 0) { - LOGGER.e("TF init status != 0: %d", initStatus); - throw new RuntimeException(); - } - } catch (final Exception e) { - throw new RuntimeException("Error initializing TensorFlow!", e); - } - detector = yoloDetector; - } else { - final TensorFlowMultiBoxDetector multiBoxDetector = new TensorFlowMultiBoxDetector(); - try { - final int initStatus = - multiBoxDetector.initializeTensorFlow( + } else { + detector = + TensorFlowMultiBoxDetector.create( getAssets(), MB_MODEL_FILE, MB_LOCATION_FILE, @@ -157,14 +146,9 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable MB_IMAGE_STD, MB_INPUT_NAME, MB_OUTPUT_NAMES); - if (initStatus != 0) { - LOGGER.e("TF init status != 0: %d", initStatus); - throw new RuntimeException(); - } - } catch (final Exception e) { - throw new RuntimeException("Error initializing TensorFlow!", e); } - detector = multiBoxDetector; + } catch (final Exception e) { + throw new RuntimeException("Error initializing TensorFlow!", e); } previewWidth = size.getWidth(); @@ -249,6 +233,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable } OverlayView trackingOverlay; + @Override public void onImageAvailable(final ImageReader reader) { Image image = null; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java index 6ea6cc27192..d1f69e8cc31 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java @@ -57,6 +57,8 @@ public class TensorFlowImageClassifier implements Classifier { private TensorFlowInferenceInterface inferenceInterface; + private TensorFlowImageClassifier() {} + /** * Initializes a native TensorFlow session for classifying images. * @@ -69,10 +71,9 @@ public class TensorFlowImageClassifier implements Classifier { * @param imageStd The assumed std of the image values. * @param inputName The label of the image input node. * @param outputName The label of the output node. - * @return The native return value, 0 indicating success. * @throws IOException */ - public int initializeTensorFlow( + public static Classifier create( AssetManager assetManager, String modelFilename, String labelFilename, @@ -81,9 +82,11 @@ public class TensorFlowImageClassifier implements Classifier { int imageMean, float imageStd, String inputName, - String outputName) throws IOException { - this.inputName = inputName; - this.outputName = outputName; + String outputName) + throws IOException { + TensorFlowImageClassifier c = new TensorFlowImageClassifier(); + c.inputName = inputName; + c.outputName = outputName; // Read the label names into memory. // TODO(andrewharp): make this handle non-assets. @@ -93,24 +96,29 @@ public class TensorFlowImageClassifier implements Classifier { br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename))); String line; while ((line = br.readLine()) != null) { - labels.add(line); + c.labels.add(line); } br.close(); - Log.i(TAG, "Read " + labels.size() + ", " + numClasses + " specified"); + Log.i(TAG, "Read " + c.labels.size() + ", " + numClasses + " specified"); - this.inputSize = inputSize; - this.imageMean = imageMean; - this.imageStd = imageStd; + c.inputSize = inputSize; + c.imageMean = imageMean; + c.imageStd = imageStd; // Pre-allocate buffers. - outputNames = new String[] {outputName}; - intValues = new int[inputSize * inputSize]; - floatValues = new float[inputSize * inputSize * 3]; - outputs = new float[numClasses]; + c.outputNames = new String[] {outputName}; + c.intValues = new int[inputSize * inputSize]; + c.floatValues = new float[inputSize * inputSize * 3]; + c.outputs = new float[numClasses]; - inferenceInterface = new TensorFlowInferenceInterface(); + c.inferenceInterface = new TensorFlowInferenceInterface(); - return inferenceInterface.initializeTensorFlow(assetManager, modelFilename); + final int status = c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename); + if (status != 0) { + Log.e(TAG, "TF init status: " + status); + throw new RuntimeException("TF init status (" + status + ") != 0"); + } + return c; } @Override @@ -147,18 +155,19 @@ public class TensorFlowImageClassifier implements Classifier { Trace.endSection(); // Find the best classifications. - PriorityQueue pq = new PriorityQueue(3, - new Comparator() { - @Override - public int compare(Recognition lhs, Recognition rhs) { - // Intentionally reversed to put high confidence at the head of the queue. - return Float.compare(rhs.getConfidence(), lhs.getConfidence()); - } - }); + PriorityQueue pq = + new PriorityQueue( + 3, + new Comparator() { + @Override + public int compare(Recognition lhs, Recognition rhs) { + // Intentionally reversed to put high confidence at the head of the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); for (int i = 0; i < outputs.length; ++i) { if (outputs[i] > THRESHOLD) { - pq.add(new Recognition( - "" + i, labels.get(i), outputs[i], null)); + pq.add(new Recognition("" + i, labels.get(i), outputs[i], null)); } } final ArrayList recognitions = new ArrayList(); diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java index 80b76051ffe..e438956c7dd 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java @@ -19,7 +19,6 @@ import android.content.res.AssetManager; import android.graphics.Bitmap; import android.graphics.RectF; import android.os.Trace; -import java.io.IOException; import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -70,10 +69,8 @@ public class TensorFlowMultiBoxDetector implements Classifier { * @param imageStd The assumed std of the image values. * @param inputName The label of the image input node. * @param outputName The label of the output node. - * @return The native return value, 0 indicating success. - * @throws IOException */ - public int initializeTensorFlow( + public static Classifier create( final AssetManager assetManager, final String modelFilename, final String locationFilename, @@ -82,30 +79,37 @@ public class TensorFlowMultiBoxDetector implements Classifier { final int imageMean, final float imageStd, final String inputName, - final String outputName) - throws IOException { - this.inputName = inputName; - this.inputSize = inputSize; - this.imageMean = imageMean; - this.imageStd = imageStd; - this.numLocations = numLocations; + final String outputName) { + TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector(); + d.inputName = inputName; + d.inputSize = inputSize; + d.imageMean = imageMean; + d.imageStd = imageStd; + d.numLocations = numLocations; - this.boxPriors = new float[numLocations * 8]; + d.boxPriors = new float[numLocations * 8]; - loadCoderOptions(assetManager, locationFilename, boxPriors); + d.loadCoderOptions(assetManager, locationFilename, d.boxPriors); // Pre-allocate buffers. - outputNames = outputName.split(","); - intValues = new int[inputSize * inputSize]; - floatValues = new float[inputSize * inputSize * 3]; - outputScores = new float[numLocations]; - outputLocations = new float[numLocations * 4]; - - inferenceInterface = new TensorFlowInferenceInterface(); - - return inferenceInterface.initializeTensorFlow(assetManager, modelFilename); + d.outputNames = outputName.split(","); + d.intValues = new int[inputSize * inputSize]; + d.floatValues = new float[inputSize * inputSize * 3]; + d.outputScores = new float[numLocations]; + d.outputLocations = new float[numLocations * 4]; + + d.inferenceInterface = new TensorFlowInferenceInterface(); + + final int status = d.inferenceInterface.initializeTensorFlow(assetManager, modelFilename); + if (status != 0) { + LOGGER.e("TF init status: " + status); + throw new RuntimeException("TF init status (" + status + ") != 0"); + } + return d; } + private TensorFlowMultiBoxDetector() {} + // Load BoxCoderOptions from native code. private native void loadCoderOptions( AssetManager assetManager, String locationFilename, float[] boxPriors); diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java index b8dd11ba051..86c922b5891 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java @@ -19,7 +19,6 @@ import android.content.res.AssetManager; import android.graphics.Bitmap; import android.graphics.RectF; import android.os.Trace; -import java.io.IOException; import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -89,34 +88,36 @@ public class TensorFlowYoloDetector implements Classifier { private TensorFlowInferenceInterface inferenceInterface; - /** - * Initializes a native TensorFlow session for classifying images. - * - * @throws IOException - */ - public int initializeTensorFlow( + /** Initializes a native TensorFlow session for classifying images. */ + public static Classifier create( final AssetManager assetManager, final String modelFilename, final int inputSize, final String inputName, final String outputName, - final int blockSize) - throws IOException { - this.inputName = inputName; - this.inputSize = inputSize; + final int blockSize) { + TensorFlowYoloDetector d = new TensorFlowYoloDetector(); + d.inputName = inputName; + d.inputSize = inputSize; // Pre-allocate buffers. - outputNames = outputName.split(","); - intValues = new int[inputSize * inputSize]; - floatValues = new float[inputSize * inputSize * 3]; - this.blockSize = blockSize; + d.outputNames = outputName.split(","); + d.intValues = new int[inputSize * inputSize]; + d.floatValues = new float[inputSize * inputSize * 3]; + d.blockSize = blockSize; - inferenceInterface = new TensorFlowInferenceInterface(); + d.inferenceInterface = new TensorFlowInferenceInterface(); - // Graphs must be converted from https://github.com/thtrieu/darkflow - return inferenceInterface.initializeTensorFlow(assetManager, modelFilename); + final int status = d.inferenceInterface.initializeTensorFlow(assetManager, modelFilename); + if (status != 0) { + LOGGER.e("TF init status: " + status); + throw new RuntimeException("TF init status (" + status + ") != 0"); + } + return d; } + private TensorFlowYoloDetector() {} + private float expit(final float x) { return (float) (1. / (1. + Math.exp(-x))); } -- GitLab