未验证 提交 1a00c049 编写于 作者: J Jason 提交者: GitHub

Merge pull request #80 from SunAhong1993/develop

fix the --caffe_proto
...@@ -40,7 +40,7 @@ x2paddle --framework=caffe --prototxt=deploy.proto --weight=deploy.caffemodel -- ...@@ -40,7 +40,7 @@ x2paddle --framework=caffe --prototxt=deploy.proto --weight=deploy.caffemodel --
|--weight | 当framework为caffe时,该参数指定caffe模型的参数文件路径 | |--weight | 当framework为caffe时,该参数指定caffe模型的参数文件路径 |
|--save_dir | 指定转换后的模型保存目录路径 | |--save_dir | 指定转换后的模型保存目录路径 |
|--model | 当framework为tensorflow时,该参数指定tensorflow的pb模型文件路径 | |--model | 当framework为tensorflow时,该参数指定tensorflow的pb模型文件路径 |
|--caffe_proto | [可选]由caffe.proto编译成caffe_pb2.py文件的存放路径,当没有安装caffe或者使用自定义Layer时使用 | |--caffe_proto | [可选]由caffe.proto编译成caffe_pb2.py文件的存放路径,当没有安装caffe或者使用自定义Layer时使用,默认为None |
## 使用转换后的模型 ## 使用转换后的模型
转换后的模型包括`model_with_code``inference_model`两个目录。 转换后的模型包括`model_with_code``inference_model`两个目录。
......
...@@ -45,11 +45,12 @@ def arg_parser(): ...@@ -45,11 +45,12 @@ def arg_parser():
type=_text_type, type=_text_type,
default=None, default=None,
help="define which deeplearning framework") help="define which deeplearning framework")
parser.add_argument("--caffe_proto", parser.add_argument(
"-c", "--caffe_proto",
type=_text_type, "-c",
default=None, type=_text_type,
help="caffe proto file of caffe model") default=None,
help="the .py file compiled by caffe proto file of caffe model")
parser.add_argument("--version", parser.add_argument("--version",
"-v", "-v",
action="store_true", action="store_true",
...@@ -92,8 +93,8 @@ def tf2paddle(model_path, save_dir): ...@@ -92,8 +93,8 @@ def tf2paddle(model_path, save_dir):
def caffe2paddle(proto, weight, save_dir, caffe_proto): def caffe2paddle(proto, weight, save_dir, caffe_proto):
if caffe_proto is not None: if caffe_proto is not None:
import os import os
if not os.path.isfile(caffe_proto + 'caffe_pb2.py'): if caffe_proto is not None and not os.path.isfile(caffe_proto):
print("The file that resolve caffe is not exist.") print("The .py file compiled by caffe.proto is not exist.")
return return
else: else:
try: try:
......
...@@ -22,18 +22,21 @@ from x2paddle.op_mapper import caffe_shape ...@@ -22,18 +22,21 @@ from x2paddle.op_mapper import caffe_shape
class CaffeResolver(object): class CaffeResolver(object):
def __init__(self, caffe_proto_folder=None): def __init__(self, caffe_proto):
self.proto_path = caffe_proto_folder self.proto_path = caffe_proto
if self.proto_path == None: if self.proto_path is None:
self.use_default = True self.use_default = True
else: else:
self.use_default = False self.use_default = False
self.import_caffe() self.import_caffe()
def import_caffepb(self): def import_caffepb(self):
sys.path.append(self.proto_path) (filepath,
import caffe_pb2 tempfilename) = os.path.split(os.path.abspath(self.proto_path))
return caffe_pb2 (filename, extension) = os.path.splitext(tempfilename)
sys.path.append(filepath)
out = __import__(filename)
return out
def import_caffe(self): def import_caffe(self):
self.caffe = None self.caffe = None
...@@ -139,9 +142,17 @@ class CaffeGraph(Graph): ...@@ -139,9 +142,17 @@ class CaffeGraph(Graph):
dim=[dims[0], dims[1], dims[2], dims[3] dim=[dims[0], dims[1], dims[2], dims[3]
]))).to_proto().layer[0]) ]))).to_proto().layer[0])
except: except:
raise ImportError( print(
'The .proto file does not work for the old style prototxt. You must install the caffe or modify the old style to new style in .protottx file.' "The .py file compiled by .proto file does not work for the old style prototxt. "
) )
print("There are 2 solutions for you as below:")
print(
"1. install caffe and don\'t set \'--caffe_proto\'."
)
print(
"2. modify your .prototxt from the old style to the new style."
)
sys.exit(-1)
data.name = self.model.input[i] data.name = self.model.input[i]
data.top[0] = self.model.input[i] data.top[0] = self.model.input[i]
else: else:
...@@ -155,9 +166,17 @@ class CaffeGraph(Graph): ...@@ -155,9 +166,17 @@ class CaffeGraph(Graph):
dim=[dims[0], dims[1], dims[2], dims[3] dim=[dims[0], dims[1], dims[2], dims[3]
]))).to_proto().layer[0]) ]))).to_proto().layer[0])
except: except:
raise ImportError( print(
'The .proto file does not work for the old style prototxt. You must install the caffe or modify the old style to new style in .protottx file.' "The .py file compiled by .proto file does not work for the old style prototxt. "
)
print("There are 2 solutions for you as below:")
print(
"1. install caffe and don\'t set \'--caffe_proto\'."
)
print(
"2. modify your .prototxt from the old style to the new style."
) )
sys.exit(-1)
data.name = self.model.input[i] data.name = self.model.input[i]
data.top[0] = self.model.input[i] data.top[0] = self.model.input[i]
layers = [data] + layers layers = [data] + layers
...@@ -202,11 +221,11 @@ class CaffeGraph(Graph): ...@@ -202,11 +221,11 @@ class CaffeGraph(Graph):
class CaffeDecoder(object): class CaffeDecoder(object):
def __init__(self, proto_path, model_path, caffe_proto_folder=None): def __init__(self, proto_path, model_path, caffe_proto=None):
self.proto_path = proto_path self.proto_path = proto_path
self.model_path = model_path self.model_path = model_path
self.resolver = CaffeResolver(caffe_proto_folder=caffe_proto_folder) self.resolver = CaffeResolver(caffe_proto=caffe_proto)
self.net = self.resolver.NetParameter() self.net = self.resolver.NetParameter()
with open(proto_path, 'rb') as proto_file: with open(proto_path, 'rb') as proto_file:
proto_str = proto_file.read() proto_str = proto_file.read()
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册