提交 2273e96a 编写于 作者: L liuyang11

support lenet and resnet convertion

### caffe2fluid ### caffe2fluid
this tool is used to convert a caffe-model to paddle-model(fluid api) this tool is used to convert a caffe-model to paddle-model(fluid api)
### howto ### howto
0, prepare caffepb.py in ./proto 1, prepare caffepb.py in ./proto, two options provided
option 1: generate it from caffe.proto using protoc option 1: generate it from caffe.proto using protoc
bash ./proto/compile.sh bash ./proto/compile.sh
option2: download one from github directly option2: download one from github directly
cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py
1, convert you caffe model using convert.py which will generate a python code and weight(in .npy) 2, convert you caffe model using convert.py which will generate a python code and weight(in .npy)
2, use the converted model to predict 3, use the converted model to predict
(see more detail info in 'tests/lenet/README.md') (see more detail info in 'tests/lenet/README.md')
......
import os
import sys import sys
SHARED_CAFFE_RESOLVER = None SHARED_CAFFE_RESOLVER = None
def import_caffepb():
p = os.path.realpath(__file__)
p = os.path.dirname(p)
p = os.path.join(p, '../../proto')
sys.path.insert(0, p)
import caffepb
return caffepb
class CaffeResolver(object): class CaffeResolver(object):
def __init__(self): def __init__(self):
self.import_caffe() self.import_caffe()
...@@ -15,8 +25,7 @@ class CaffeResolver(object): ...@@ -15,8 +25,7 @@ class CaffeResolver(object):
self.caffe = caffe self.caffe = caffe
except ImportError: except ImportError:
# Fall back to the protobuf implementation # Fall back to the protobuf implementation
from . import caffepb self.caffepb = import_caffepb()
self.caffepb = caffepb
show_fallback_warning() show_fallback_warning()
if self.caffe: if self.caffe:
# Use the protobuf code from the imported distribution. # Use the protobuf code from the imported distribution.
......
...@@ -165,7 +165,6 @@ class Network(object): ...@@ -165,7 +165,6 @@ class Network(object):
# Get the number of channels in the input # Get the number of channels in the input
h_i, w_i = input.shape[2:] h_i, w_i = input.shape[2:]
fluid = import_fluid() fluid = import_fluid()
output = fluid.layers.pool2d( output = fluid.layers.pool2d(
input=input, input=input,
...@@ -182,7 +181,6 @@ class Network(object): ...@@ -182,7 +181,6 @@ class Network(object):
# Get the number of channels in the input # Get the number of channels in the input
h_i, w_i = input.shape[2:] h_i, w_i = input.shape[2:]
fluid = import_fluid() fluid = import_fluid()
output = fluid.layers.pool2d( output = fluid.layers.pool2d(
input=input, input=input,
......
...@@ -273,7 +273,6 @@ class TensorFlowEmitter(object): ...@@ -273,7 +273,6 @@ class TensorFlowEmitter(object):
b += self.emit_node(node) b += self.emit_node(node)
blocks.append(b[:-1]) blocks.append(b[:-1])
s = s + '\n\n'.join(blocks) s = s + '\n\n'.join(blocks)
s += self.emit_convert_def(input_nodes) s += self.emit_convert_def(input_nodes)
s += self.emit_main_def(name) s += self.emit_main_def(name)
return s return s
......
### convert lenet model from caffe format into paddle format(fluid api) ### convert lenet model from caffe format into paddle format(fluid api)
### howto ### howto
0, prepare your caffepb.py
1, download a lenet caffe-model 1, download a lenet caffe-model
lenet_iter_10000.caffemodel lenet_iter_10000.caffemodel
download address: https://github.com/ethereon/caffe-tensorflow/raw/master/examples/mnist/lenet_iter_10000.caffemodel download address: https://github.com/ethereon/caffe-tensorflow/raw/master/examples/mnist/lenet_iter_10000.caffemodel
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#function: #function:
# convert a caffe model # convert a caffe model
# eg:
# bash ./convert.sh ./model.caffe/lenet.prototxt ./model.caffe/lenet.caffemodel lenet.py lenet.npy
if [[ $# -ne 4 ]];then if [[ $# -ne 4 ]];then
echo "usage:" echo "usage:"
......
...@@ -61,8 +61,8 @@ class Network(object): ...@@ -61,8 +61,8 @@ class Network(object):
fluid = import_fluid() fluid = import_fluid()
#load fluid mode directly #load fluid mode directly
if os.path.isdir(data_path): if os.path.isdir(data_path):
assert ( assert (exe is not None), \
exe is not None), 'must provide a executor to load fluid model' 'must provide a executor to load fluid model'
fluid.io.load_persistables_if_exist(executor=exe, dirname=data_path) fluid.io.load_persistables_if_exist(executor=exe, dirname=data_path)
return True return True
...@@ -167,7 +167,6 @@ class Network(object): ...@@ -167,7 +167,6 @@ class Network(object):
# Get the number of channels in the input # Get the number of channels in the input
h_i, w_i = input.shape[2:] h_i, w_i = input.shape[2:]
fluid = import_fluid() fluid = import_fluid()
output = fluid.layers.pool2d( output = fluid.layers.pool2d(
input=input, input=input,
...@@ -184,7 +183,6 @@ class Network(object): ...@@ -184,7 +183,6 @@ class Network(object):
# Get the number of channels in the input # Get the number of channels in the input
h_i, w_i = input.shape[2:] h_i, w_i = input.shape[2:]
fluid = import_fluid() fluid = import_fluid()
output = fluid.layers.pool2d( output = fluid.layers.pool2d(
input=input, input=input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册