# imagenet_cls_test_inception.py
import numpy as np
import sys
import os
import argparse
import tensorflow as tf
from tensorflow.python.platform import gfile
from imagenet_cls_test_alexnet import MeanValueFetch, DnnCaffeModel, Framework, ClsAccEvaluation
try:
    import cv2 as cv
except ImportError:
    # Re-raise with setup guidance: a bare "No module named cv2" does not tell
    # the user that a source build needs PYTHONPATH pointed at its lib directory.
    raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
                      'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')

# If you get an exception such as "Cannot load libmkl_avx.so or libmkl_def.so", try exporting the following
# variable before running the script:
# LD_PRELOAD=/opt/intel/mkl/lib/intel64/libmkl_core.so:/opt/intel/mkl/lib/intel64/libmkl_sequential.so


class TensorflowModel(Framework):
    """Framework implementation that evaluates a serialized TensorFlow graph.

    The graph is read from a binary ``GraphDef`` file, imported into the graph
    owned by a dedicated ``tf.Session``, and evaluated through that session.
    """

    # Class-level placeholders; both are rebound to real objects in __init__.
    sess = tf.Session
    output = tf.Graph

    def __init__(self, model_file, in_blob_name, out_blob_name):
        """Load the model and resolve the output tensor.

        Args:
            model_file: path to a binary serialized ``GraphDef``.
            in_blob_name: name of the input op; ":0" is appended when feeding.
            out_blob_name: name of the output op; ":0" is appended to look up
                its first output tensor.
        """
        self.in_blob_name = in_blob_name
        self.sess = tf.Session()
        with gfile.FastGFile(model_file, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # Graph.as_default() returns a context manager; calling it without
        # `with` has no effect. Enter it so the import explicitly targets the
        # graph this session executes, whatever the ambient default graph is.
        with self.sess.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        self.output = self.sess.graph.get_tensor_by_name(out_blob_name + ":0")

    def get_name(self):
        """Return the display name of this framework."""
        return 'Tensorflow'

    def get_output(self, input_blob):
        """Run the graph on a 4-D batch and return a slice of its output.

        Args:
            input_blob: 4-D array. Its axes are reordered with
                ``transpose(0, 2, 3, 1)`` before feeding (presumably
                NCHW -> NHWC; confirm against the data fetcher).

        Returns:
            The output tensor restricted to indices 1..1000 of its last axis.
        """
        assert len(input_blob.shape) == 4
        batch_tf = input_blob.transpose(0, 2, 3, 1)
        out = self.sess.run(self.output,
                            {self.in_blob_name + ':0': batch_tf})
        # NOTE(review): index 0 is dropped and 1000 entries are kept; this looks
        # like skipping an extra leading class in the model output -- confirm.
        return out[..., 1:1001]


class DnnTfInceptionModel(DnnCaffeModel):
    """OpenCV DNN model loaded from a TensorFlow graph file.

    Inference is delegated to ``DnnCaffeModel.get_output``; this class only
    changes how the network is loaded and which output entries are returned.
    """

    # Class-level placeholder; rebound to the loaded network in __init__.
    net = cv.dnn.Net()

    def __init__(self, model_file, in_blob_name, out_blob_name):
        """Read the network from ``model_file`` and store the blob names."""
        self.net = cv.dnn.readNetFromTensorflow(model_file)
        self.in_blob_name, self.out_blob_name = in_blob_name, out_blob_name

    def get_output(self, input_blob):
        """Return the parent's output restricted to indices 1..1000 of the last axis."""
        scores = super(DnnTfInceptionModel, self).get_output(input_blob)
        return scores[..., 1:1001]


def build_arg_parser():
    """Create the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--imgs_dir``,
        ``--img_cls_file``, ``--model``, ``--log``, ``--batch_size``,
        ``--frame_size``, ``--in_blob`` and ``--out_blob``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgs_dir", help="path to ImageNet validation subset images dir, ILSVRC2012_img_val dir")
    parser.add_argument("--img_cls_file", help="path to file with classes ids for images, download it here: "
                            "https://github.com/opencv/opencv_extra/tree/3.4/testdata/dnn/img_classes_inception.txt")
    parser.add_argument("--model", help="path to tensorflow model, download it here: "
                                        "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip")
    parser.add_argument("--log", help="path to logging file")
    # type=int keeps command-line values consistent with the integer defaults;
    # without it argparse hands over strings whenever the option is given.
    parser.add_argument("--batch_size", help="size of images in batch", default=1, type=int)
    parser.add_argument("--frame_size", help="size of input image", default=224, type=int)
    parser.add_argument("--in_blob", help="name for input blob", default='input')
    parser.add_argument("--out_blob", help="name for output blob", default='softmax2')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    data_fetcher = MeanValueFetch(args.frame_size, args.imgs_dir, True)

    # Both frameworks load the same model file so their outputs can be compared.
    frameworks = [TensorflowModel(args.model, args.in_blob, args.out_blob),
                  DnnTfInceptionModel(args.model, '', args.out_blob)]

    acc_eval = ClsAccEvaluation(args.log, args.img_cls_file, args.batch_size)
    acc_eval.process(frameworks, data_fetcher)