提交 8ec03d57 编写于 作者: S SunAhong1993

fix the --caffe_proto

上级 d6af8cf0
...@@ -48,7 +48,7 @@ def arg_parser(): ...@@ -48,7 +48,7 @@ def arg_parser():
parser.add_argument("--caffe_proto", parser.add_argument("--caffe_proto",
"-c", "-c",
type=_text_type, type=_text_type,
default=None, default='./x2paddle/decoder/caffe_pb2.py',
help="caffe proto file of caffe model") help="caffe proto file of caffe model")
parser.add_argument("--version", parser.add_argument("--version",
"-v", "-v",
...@@ -92,7 +92,8 @@ def tf2paddle(model_path, save_dir): ...@@ -92,7 +92,8 @@ def tf2paddle(model_path, save_dir):
def caffe2paddle(proto, weight, save_dir, caffe_proto): def caffe2paddle(proto, weight, save_dir, caffe_proto):
if caffe_proto is not None: if caffe_proto is not None:
import os import os
if not os.path.isfile(caffe_proto + 'caffe_pb2.py'): print(caffe_proto)
if caffe_proto != 'None' and not os.path.isfile(caffe_proto):
print("The file that resolve caffe is not exist.") print("The file that resolve caffe is not exist.")
return return
else: else:
......
...@@ -22,18 +22,21 @@ from x2paddle.op_mapper import caffe_shape ...@@ -22,18 +22,21 @@ from x2paddle.op_mapper import caffe_shape
class CaffeResolver(object): class CaffeResolver(object):
def __init__(self, caffe_proto_folder=None): def __init__(self, caffe_proto):
self.proto_path = caffe_proto_folder self.proto_path = caffe_proto
if self.proto_path == None: if self.proto_path == 'None':
self.use_default = True self.use_default = True
else: else:
self.use_default = False self.use_default = False
self.import_caffe() self.import_caffe()
def import_caffepb(self): def import_caffepb(self):
sys.path.append(self.proto_path) (filepath,
import caffe_pb2 tempfilename) = os.path.split(os.path.abspath(self.proto_path))
return caffe_pb2 (filename, extension) = os.path.splitext(tempfilename)
sys.path.append(filepath)
out = __import__(filename)
return out
def import_caffe(self): def import_caffe(self):
self.caffe = None self.caffe = None
...@@ -139,9 +142,16 @@ class CaffeGraph(Graph): ...@@ -139,9 +142,16 @@ class CaffeGraph(Graph):
dim=[dims[0], dims[1], dims[2], dims[3] dim=[dims[0], dims[1], dims[2], dims[3]
]))).to_proto().layer[0]) ]))).to_proto().layer[0])
except: except:
raise ImportError( print(
'The .proto file does not work for the old style prototxt. You must install the caffe or modify the old style to new style in .protottx file.' "The .py fiel compiled by .proto file does not work for the old style prototxt. "
) )
print("There are 2 solutions for you as below:")
print(
"1. install caffe and set \'--caffe_proto=None\'.")
print(
"2. modify your .prototxt from the old style to the new style."
)
sys.exit(-1)
data.name = self.model.input[i] data.name = self.model.input[i]
data.top[0] = self.model.input[i] data.top[0] = self.model.input[i]
else: else:
...@@ -155,9 +165,16 @@ class CaffeGraph(Graph): ...@@ -155,9 +165,16 @@ class CaffeGraph(Graph):
dim=[dims[0], dims[1], dims[2], dims[3] dim=[dims[0], dims[1], dims[2], dims[3]
]))).to_proto().layer[0]) ]))).to_proto().layer[0])
except: except:
raise ImportError( print(
'The .proto file does not work for the old style prototxt. You must install the caffe or modify the old style to new style in .protottx file.' "The .py fiel compiled by .proto file does not work for the old style prototxt. "
)
print("There are 2 solutions for you as below:")
print(
"1. install caffe and set \'--caffe_proto=None\'.")
print(
"2. modify your .prototxt from the old style to the new style."
) )
sys.exit(-1)
data.name = self.model.input[i] data.name = self.model.input[i]
data.top[0] = self.model.input[i] data.top[0] = self.model.input[i]
layers = [data] + layers layers = [data] + layers
...@@ -202,11 +219,14 @@ class CaffeGraph(Graph): ...@@ -202,11 +219,14 @@ class CaffeGraph(Graph):
class CaffeDecoder(object): class CaffeDecoder(object):
def __init__(self, proto_path, model_path, caffe_proto_folder=None): def __init__(self,
proto_path,
model_path,
caffe_proto='./x2paddle/decoder/caffe_pb2.py'):
self.proto_path = proto_path self.proto_path = proto_path
self.model_path = model_path self.model_path = model_path
self.resolver = CaffeResolver(caffe_proto_folder=caffe_proto_folder) self.resolver = CaffeResolver(caffe_proto=caffe_proto)
self.net = self.resolver.NetParameter() self.net = self.resolver.NetParameter()
with open(proto_path, 'rb') as proto_file: with open(proto_path, 'rb') as proto_file:
proto_str = proto_file.read() proto_str = proto_file.read()
......
因为 它太大了无法显示 source diff 。你可以改为 查看blob
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册