diff --git a/x2paddle/decoder/caffe_decoder.py b/x2paddle/decoder/caffe_decoder.py index d6a925170d259332e7d1cc7f89bf935696ceca16..e4bd86c4833944ea325868f5770b61e836b90843 100644 --- a/x2paddle/decoder/caffe_decoder.py +++ b/x2paddle/decoder/caffe_decoder.py @@ -27,11 +27,18 @@ class CaffeResolver(object): self.import_caffe() def import_caffepb(self): - (filepath, - tempfilename) = os.path.split(os.path.abspath(self.caffe_proto)) - (filename, extension) = os.path.splitext(tempfilename) - sys.path.append(filepath) - out = __import__(filename) + if self.caffe_proto is None: + from x2paddle.decoder import caffe_pb2 + out = caffe_pb2 + else: + if not os.path.isfile(self.caffe_proto): + raise Exception( + "The .py file compiled by caffe.proto is not exist.") + (filepath, + tempfilename) = os.path.split(os.path.abspath(self.caffe_proto)) + (filename, extension) = os.path.splitext(tempfilename) + sys.path.append(filepath) + out = __import__(filename) return out def import_caffe(self): @@ -146,11 +153,6 @@ class CaffeDecoder(object): def __init__(self, proto_path, model_path, caffe_proto): self.proto_path = proto_path self.model_path = model_path - if caffe_proto is None: - caffe_proto = './x2paddle/decoder/caffe_pb2.py' - if caffe_proto is not None and not os.path.isfile(caffe_proto): - raise Exception( - "The .py file compiled by caffe.proto is not exist.") self.resolver = CaffeResolver(caffe_proto=caffe_proto) self.net = self.resolver.NetParameter()