# resolver.py: picks the Caffe protobuf implementation to use -- PyCaffe's
# when it can be imported, otherwise the standalone caffe_pb2 module.
import os
import re
import subprocess
import sys

# Module-wide CaffeResolver instance; stays None until get_caffe_resolver()
# creates it on first use.
SHARED_CAFFE_RESOLVER = None


def _parse_protoc_version(output):
    """Extract the numeric version from ``protoc --version`` output.

    Only the last whitespace-separated token is inspected, e.g. the
    ``3.6.1`` in ``'libprotoc 3.6.1'``.

    Args:
        output: Text printed by ``protoc --version``.

    Returns:
        A tuple of ints, one per dotted component, e.g. ``(3, 6, 1)``.

    Raises:
        ValueError: If the last token does not start with a number, which is
            what happens when ``protoc`` is missing and the shell prints an
            error message instead of a version.
    """
    tokens = output.split()
    match = re.match(r'\d+(?:\.\d+)*', tokens[-1]) if tokens else None
    if match is None:
        raise ValueError(
            'Cannot determine the protoc version from %r; '
            'is protoc installed and on PATH?' % (output,))
    return tuple(int(part) for part in match.group(0).split('.'))


def import_caffepb():
    """Import and return the standalone ``caffe_pb2`` module.

    The ``../../proto`` directory (relative to this file) is put at the front
    of ``sys.path`` so that ``caffe_pb2`` can be found, and the installed
    ``protoc`` is required to be version 3.6.0 or newer.

    Returns:
        The imported ``caffe_pb2`` module.

    Raises:
        ValueError: If the output of ``protoc --version`` cannot be parsed.
        AssertionError: If the detected protoc version is older than 3.6.0.
        ImportError: If ``caffe_pb2`` cannot be imported.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    sys.path.insert(0, os.path.join(here, '../../proto'))

    if sys.version_info[0] == 2:
        # Python 2 has no subprocess.getstatusoutput; use its predecessor.
        import commands
        pb_version = commands.getstatusoutput('protoc --version')[1]
    else:
        pb_version = subprocess.getstatusoutput('protoc --version')[1]

    version = _parse_protoc_version(pb_version)
    # Compare component-wise rather than by concatenating digits: "25.1"
    # must rank above 3.6.0 even though 251 < 360. Pad short versions so
    # that "3.6" is treated as 3.6.0.
    version += (0,) * (3 - len(version))
    if version < (3, 6, 0):
        # Raised explicitly (not via ``assert``) so the check survives -O.
        raise AssertionError(
            'The version of protobuf must be larger than 3.6.0!')

    import caffe_pb2
    return caffe_pb2


class CaffeResolver(object):
    """Resolve which Caffe protobuf bindings are available.

    After construction the instance exposes:
      * ``caffe``        -- the imported PyCaffe module, or None if it could
                            not be imported;
      * ``caffepb``      -- the ``caffe_pb2`` module in use;
      * ``NetParameter`` -- shortcut to ``caffepb.NetParameter``.
    """

    def __init__(self):
        self.import_caffe()

    def import_caffe(self):
        """Populate ``caffe``, ``caffepb`` and ``NetParameter``."""
        self.caffe = None
        try:
            # PyCaffe is the preferred source of the protobuf definitions.
            import caffe as pycaffe
        except ImportError:
            # No PyCaffe: use the standalone protobuf module and warn.
            self.caffepb = import_caffepb()
            show_fallback_warning()
        else:
            self.caffe = pycaffe
            # Take the protobuf code from the imported distribution so that
            # Caffe variants with custom layers keep working.
            self.caffepb = pycaffe.proto.caffe_pb2
        self.NetParameter = self.caffepb.NetParameter

    def has_pycaffe(self):
        """Return True when PyCaffe was imported successfully."""
        return not (self.caffe is None)


def get_caffe_resolver():
    """Return the shared CaffeResolver, creating it on the first call."""
    global SHARED_CAFFE_RESOLVER
    if SHARED_CAFFE_RESOLVER is not None:
        return SHARED_CAFFE_RESOLVER
    SHARED_CAFFE_RESOLVER = CaffeResolver()
    return SHARED_CAFFE_RESOLVER


def has_pycaffe():
    """Report whether the shared resolver managed to import PyCaffe."""
    resolver = get_caffe_resolver()
    return resolver.has_pycaffe()


def show_fallback_warning():
    """Write the "PyCaffe not found" banner to standard error."""
    sys.stderr.write('''
------------------------------------------------------------
    WARNING: PyCaffe not found!
    Falling back to a pure protocol buffer implementation.
    * Conversions will be drastically slower.
------------------------------------------------------------

''')