提交 df98794b 编写于 作者: A Asim Shankar 提交者: TensorFlower Gardener

Android: Use the Classifier interface in the activities.

In the activity classes, declare the classifier member variable as the
Classifier interface type instead of a concrete implementation
(e.g. TensorFlowImageClassifier).

My intention is to create new Classifier implementations that use
the TensorFlow Java API (org.tensorflow.Graph, Session etc.) instead
of the Android contrib API
(org.tensorflow.contrib.android.TensorFlowInferenceInterface). This
re-organization will make the switch between Classifier implementations
easier during testing.

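While at it, some minor cleanups:
- Get rid of unnecessary "throws IOException" declarations (on the
  MultiBox and YOLO detectors, which perform no Java-side file reads).
- Use a static factory function (create) instead of constructing an
  instance and then calling an initializer function
  (initializeTensorFlow); a non-zero init status now raises a
  RuntimeException inside the factory.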
Change: 144348772
上级 1f409de8
......@@ -63,7 +63,7 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
private static final boolean MAINTAIN_ASPECT = true;
private TensorFlowImageClassifier classifier;
private Classifier classifier;
private Integer sensorOrientation;
......@@ -78,7 +78,6 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
private boolean computing = false;
private Matrix frameToCropTransform;
private Matrix cropToFrameTransform;
......@@ -102,17 +101,15 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
final float textSizePx = TypedValue.applyDimension(
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP,
getResources().getDisplayMetrics());
final float textSizePx =
TypedValue.applyDimension(
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
classifier = new TensorFlowImageClassifier();
try {
final int initStatus =
classifier.initializeTensorFlow(
classifier =
TensorFlowImageClassifier.create(
getAssets(),
MODEL_FILE,
LABEL_FILE,
......@@ -122,10 +119,6 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
IMAGE_STD,
INPUT_NAME,
OUTPUT_NAME);
if (initStatus != 0) {
LOGGER.e("TF init status != 0: %d", initStatus);
throw new RuntimeException();
}
} catch (final Exception e) {
throw new RuntimeException("Error initializing TensorFlow!", e);
}
......@@ -137,8 +130,7 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
final Display display = getWindowManager().getDefaultDisplay();
final int screenOrientation = display.getRotation();
LOGGER.i("Sensor orientation: %d, Screen orientation: %d",
rotation, screenOrientation);
LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation);
sensorOrientation = rotation + screenOrientation;
......@@ -147,22 +139,24 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
frameToCropTransform = ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
INPUT_SIZE, INPUT_SIZE,
sensorOrientation, MAINTAIN_ASPECT);
frameToCropTransform =
ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
INPUT_SIZE, INPUT_SIZE,
sensorOrientation, MAINTAIN_ASPECT);
cropToFrameTransform = new Matrix();
frameToCropTransform.invert(cropToFrameTransform);
yuvBytes = new byte[3][];
addCallback(new DrawCallback() {
@Override
public void drawCallback(final Canvas canvas) {
renderDebug(canvas);
}
});
addCallback(
new DrawCallback() {
@Override
public void drawCallback(final Canvas canvas) {
renderDebug(canvas);
}
});
}
@Override
......
......@@ -124,30 +124,19 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
tracker = new MultiBoxTracker(getResources().getDisplayMetrics());
if (USE_YOLO) {
final TensorFlowYoloDetector yoloDetector = new TensorFlowYoloDetector();
try {
final int initStatus =
yoloDetector.initializeTensorFlow(
try {
if (USE_YOLO) {
detector =
TensorFlowYoloDetector.create(
getAssets(),
YOLO_MODEL_FILE,
YOLO_INPUT_SIZE,
YOLO_INPUT_NAME,
YOLO_OUTPUT_NAMES,
YOLO_BLOCK_SIZE);
if (initStatus != 0) {
LOGGER.e("TF init status != 0: %d", initStatus);
throw new RuntimeException();
}
} catch (final Exception e) {
throw new RuntimeException("Error initializing TensorFlow!", e);
}
detector = yoloDetector;
} else {
final TensorFlowMultiBoxDetector multiBoxDetector = new TensorFlowMultiBoxDetector();
try {
final int initStatus =
multiBoxDetector.initializeTensorFlow(
} else {
detector =
TensorFlowMultiBoxDetector.create(
getAssets(),
MB_MODEL_FILE,
MB_LOCATION_FILE,
......@@ -157,14 +146,9 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
MB_IMAGE_STD,
MB_INPUT_NAME,
MB_OUTPUT_NAMES);
if (initStatus != 0) {
LOGGER.e("TF init status != 0: %d", initStatus);
throw new RuntimeException();
}
} catch (final Exception e) {
throw new RuntimeException("Error initializing TensorFlow!", e);
}
detector = multiBoxDetector;
} catch (final Exception e) {
throw new RuntimeException("Error initializing TensorFlow!", e);
}
previewWidth = size.getWidth();
......@@ -249,6 +233,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
}
OverlayView trackingOverlay;
@Override
public void onImageAvailable(final ImageReader reader) {
Image image = null;
......
......@@ -57,6 +57,8 @@ public class TensorFlowImageClassifier implements Classifier {
private TensorFlowInferenceInterface inferenceInterface;
/** Kept private so that instances are obtained through the static {@code create(...)} factory. */
private TensorFlowImageClassifier() {}
/**
* Initializes a native TensorFlow session for classifying images.
*
......@@ -69,10 +71,9 @@ public class TensorFlowImageClassifier implements Classifier {
* @param imageStd The assumed std of the image values.
* @param inputName The label of the image input node.
* @param outputName The label of the output node.
* @return The native return value, 0 indicating success.
* @throws IOException
*/
public int initializeTensorFlow(
public static Classifier create(
AssetManager assetManager,
String modelFilename,
String labelFilename,
......@@ -81,9 +82,11 @@ public class TensorFlowImageClassifier implements Classifier {
int imageMean,
float imageStd,
String inputName,
String outputName) throws IOException {
this.inputName = inputName;
this.outputName = outputName;
String outputName)
throws IOException {
TensorFlowImageClassifier c = new TensorFlowImageClassifier();
c.inputName = inputName;
c.outputName = outputName;
// Read the label names into memory.
// TODO(andrewharp): make this handle non-assets.
......@@ -93,24 +96,29 @@ public class TensorFlowImageClassifier implements Classifier {
br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
String line;
while ((line = br.readLine()) != null) {
labels.add(line);
c.labels.add(line);
}
br.close();
Log.i(TAG, "Read " + labels.size() + ", " + numClasses + " specified");
Log.i(TAG, "Read " + c.labels.size() + ", " + numClasses + " specified");
this.inputSize = inputSize;
this.imageMean = imageMean;
this.imageStd = imageStd;
c.inputSize = inputSize;
c.imageMean = imageMean;
c.imageStd = imageStd;
// Pre-allocate buffers.
outputNames = new String[] {outputName};
intValues = new int[inputSize * inputSize];
floatValues = new float[inputSize * inputSize * 3];
outputs = new float[numClasses];
c.outputNames = new String[] {outputName};
c.intValues = new int[inputSize * inputSize];
c.floatValues = new float[inputSize * inputSize * 3];
c.outputs = new float[numClasses];
inferenceInterface = new TensorFlowInferenceInterface();
c.inferenceInterface = new TensorFlowInferenceInterface();
return inferenceInterface.initializeTensorFlow(assetManager, modelFilename);
final int status = c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename);
if (status != 0) {
Log.e(TAG, "TF init status: " + status);
throw new RuntimeException("TF init status (" + status + ") != 0");
}
return c;
}
@Override
......@@ -147,18 +155,19 @@ public class TensorFlowImageClassifier implements Classifier {
Trace.endSection();
// Find the best classifications.
PriorityQueue<Recognition> pq = new PriorityQueue<Recognition>(3,
new Comparator<Recognition>() {
@Override
public int compare(Recognition lhs, Recognition rhs) {
// Intentionally reversed to put high confidence at the head of the queue.
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
}
});
PriorityQueue<Recognition> pq =
new PriorityQueue<Recognition>(
3,
new Comparator<Recognition>() {
@Override
public int compare(Recognition lhs, Recognition rhs) {
// Intentionally reversed to put high confidence at the head of the queue.
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
}
});
for (int i = 0; i < outputs.length; ++i) {
if (outputs[i] > THRESHOLD) {
pq.add(new Recognition(
"" + i, labels.get(i), outputs[i], null));
pq.add(new Recognition("" + i, labels.get(i), outputs[i], null));
}
}
final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
......
......@@ -19,7 +19,6 @@ import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.os.Trace;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
......@@ -70,10 +69,8 @@ public class TensorFlowMultiBoxDetector implements Classifier {
* @param imageStd The assumed std of the image values.
* @param inputName The label of the image input node.
* @param outputName The label of the output node.
* @return The native return value, 0 indicating success.
* @throws IOException
*/
public int initializeTensorFlow(
public static Classifier create(
final AssetManager assetManager,
final String modelFilename,
final String locationFilename,
......@@ -82,30 +79,37 @@ public class TensorFlowMultiBoxDetector implements Classifier {
final int imageMean,
final float imageStd,
final String inputName,
final String outputName)
throws IOException {
this.inputName = inputName;
this.inputSize = inputSize;
this.imageMean = imageMean;
this.imageStd = imageStd;
this.numLocations = numLocations;
final String outputName) {
TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector();
d.inputName = inputName;
d.inputSize = inputSize;
d.imageMean = imageMean;
d.imageStd = imageStd;
d.numLocations = numLocations;
this.boxPriors = new float[numLocations * 8];
d.boxPriors = new float[numLocations * 8];
loadCoderOptions(assetManager, locationFilename, boxPriors);
d.loadCoderOptions(assetManager, locationFilename, d.boxPriors);
// Pre-allocate buffers.
outputNames = outputName.split(",");
intValues = new int[inputSize * inputSize];
floatValues = new float[inputSize * inputSize * 3];
outputScores = new float[numLocations];
outputLocations = new float[numLocations * 4];
inferenceInterface = new TensorFlowInferenceInterface();
return inferenceInterface.initializeTensorFlow(assetManager, modelFilename);
d.outputNames = outputName.split(",");
d.intValues = new int[inputSize * inputSize];
d.floatValues = new float[inputSize * inputSize * 3];
d.outputScores = new float[numLocations];
d.outputLocations = new float[numLocations * 4];
d.inferenceInterface = new TensorFlowInferenceInterface();
final int status = d.inferenceInterface.initializeTensorFlow(assetManager, modelFilename);
if (status != 0) {
LOGGER.e("TF init status: " + status);
throw new RuntimeException("TF init status (" + status + ") != 0");
}
return d;
}
/** Kept private so that instances are obtained through the static {@code create(...)} factory. */
private TensorFlowMultiBoxDetector() {}
// Load BoxCoderOptions from native code.
// Declared native, so the implementation is not in this Java file. create() calls this with a
// boxPriors array of length numLocations * 8; presumably the native side fills that array in
// place from the asset named by locationFilename -- TODO confirm against the native code.
private native void loadCoderOptions(
AssetManager assetManager, String locationFilename, float[] boxPriors);
......
......@@ -19,7 +19,6 @@ import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.os.Trace;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
......@@ -89,34 +88,36 @@ public class TensorFlowYoloDetector implements Classifier {
private TensorFlowInferenceInterface inferenceInterface;
/**
* Initializes a native TensorFlow session for classifying images.
*
* @throws IOException
*/
public int initializeTensorFlow(
/** Initializes a native TensorFlow session for classifying images. */
public static Classifier create(
final AssetManager assetManager,
final String modelFilename,
final int inputSize,
final String inputName,
final String outputName,
final int blockSize)
throws IOException {
this.inputName = inputName;
this.inputSize = inputSize;
final int blockSize) {
TensorFlowYoloDetector d = new TensorFlowYoloDetector();
d.inputName = inputName;
d.inputSize = inputSize;
// Pre-allocate buffers.
outputNames = outputName.split(",");
intValues = new int[inputSize * inputSize];
floatValues = new float[inputSize * inputSize * 3];
this.blockSize = blockSize;
d.outputNames = outputName.split(",");
d.intValues = new int[inputSize * inputSize];
d.floatValues = new float[inputSize * inputSize * 3];
d.blockSize = blockSize;
inferenceInterface = new TensorFlowInferenceInterface();
d.inferenceInterface = new TensorFlowInferenceInterface();
// Graphs must be converted from https://github.com/thtrieu/darkflow
return inferenceInterface.initializeTensorFlow(assetManager, modelFilename);
final int status = d.inferenceInterface.initializeTensorFlow(assetManager, modelFilename);
if (status != 0) {
LOGGER.e("TF init status: " + status);
throw new RuntimeException("TF init status (" + status + ") != 0");
}
return d;
}
/** Kept private so that instances are obtained through the static {@code create(...)} factory. */
private TensorFlowYoloDetector() {}
/**
 * Applies the logistic sigmoid, 1 / (1 + e^(-x)), to a single value.
 *
 * @param x the input value.
 * @return the sigmoid of {@code x}, evaluated in double precision and narrowed to float.
 */
private float expit(final float x) {
  // Math.exp works on doubles; the float argument is widened before the call.
  final double decay = Math.exp(-x);
  final double sigmoid = 1. / (1. + decay);
  return (float) sigmoid;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册