提交 4ad6b207 编写于 作者: W wanglong03

change demos to example folder

上级 7ac1c7d6
### Convert lenet model from caffe format into paddle format(fluid api)
### Howto
1, Prepare your caffepb.py
2, Download a lenet caffe-model
lenet_iter_10000.caffemodel
download address: https://github.com/ethereon/caffe-tensorflow/raw/master/examples/mnist/lenet_iter_10000.caffemodel
md5: cbec75c1c374b6c1981c4a1eb024ae01
lenet.prototxt
download address: https://raw.githubusercontent.com/BVLC/caffe/master/examples/mnist/lenet.prototxt
md5: 27384af843338ab90b00c8d1c81de7d5
2, Convert this model(make sure caffepb.py is ready in ../../proto)
convert to npy format
bash ./convert.sh lenet.prototxt lenet.caffemodel lenet.py lenet.npy
save to fluid format(optional)
bash ./convert.sh lenet.prototxt lenet.caffemodel lenet.py lenet.npy && python ./lenet.py ./lenet.npy ./fluid.model
4, Use this new model(paddle installed in this python)
use fluid format
python ./predict.py ./fluid.model
use npy format
python ./predict.py ./lenet.npy
#!/bin/bash
#function:
# convert a caffe model
# eg:
# bash ./convert.sh ./model.caffe/lenet.prototxt ./model.caffe/lenet.caffemodel lenet.py lenet.npy
if [[ $# -ne 4 ]];then
echo "usage:"
echo " bash $0 [PROTOTXT] [CAFFEMODEL] [PY_NAME] [WEIGHT_NAME]"
echo " eg: bash $0 lenet.prototxt lenet.caffemodel lenet.py lenet.npy"
exit 1
fi
WORK_ROOT=$(dirname `readlink -f ${BASH_SOURCE[0]}`)
if [[ -z $PYTHON ]];then
PYTHON=`which python`
fi
PROTOTXT=$1
CAFFEMODEL=$2
PY_NAME=$3
WEIGHT_NAME=$4
CONVERTER_PY="$WORK_ROOT/../../convert.py"
$PYTHON $CONVERTER_PY $PROTOTXT --caffemodel $CAFFEMODEL --code-output-path=$PY_NAME --data-output-path=$WEIGHT_NAME
ret=$?
if [[ $ret -eq 0 ]];then
echo "succeed to convert caffe model[$CAFFEMODEL, $PROTOTXT] to paddle model[$PY_NAME, $WEIGHT_NAME]"
else
echo "failed to convert caffe model[$CAFFEMODEL, $PROTOTXT]"
fi
exit $ret
### generated by caffe2fluid, your net is in class "LeNet" ###
import math
import os
import numpy as np
def import_fluid():
import paddle.v2.fluid as fluid
return fluid
def layer(op):
'''Decorator for composable network layers.'''
def layer_decorated(self, *args, **kwargs):
# Automatically set a name if not provided.
name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
# Figure out the layer inputs.
if len(self.terminals) == 0:
raise RuntimeError('No input variables found for layer %s.' % name)
elif len(self.terminals) == 1:
layer_input = self.terminals[0]
else:
layer_input = list(self.terminals)
# Perform the operation and get the output.
layer_output = op(self, layer_input, *args, **kwargs)
# Add to layer LUT.
self.layers[name] = layer_output
# This output is now the input for the next layer.
self.feed(layer_output)
# Return self for chained calls.
return self
return layer_decorated
class Network(object):
def __init__(self, inputs, trainable=True):
# The input nodes for this network
self.inputs = inputs
# The current list of terminal nodes
self.terminals = []
# Mapping from layer names to layers
self.layers = dict(inputs)
# If true, the resulting variables are set as trainable
self.trainable = trainable
# Switch variable for dropout
self.paddle_env = None
self.setup()
def setup(self):
'''Construct the network. '''
raise NotImplementedError('Must be implemented by the subclass.')
def load(self, data_path, exe=None, place=None, ignore_missing=False):
'''Load network weights.
data_path: The path to the numpy-serialized network weights
ignore_missing: If true, serialized weights for missing layers are ignored.
'''
fluid = import_fluid()
#load fluid mode directly
if os.path.isdir(data_path):
assert (exe is not None), \
'must provide a executor to load fluid model'
fluid.io.load_persistables_if_exist(executor=exe, dirname=data_path)
return True
#load model from a npy file
if exe is None or place is None:
if self.paddle_env is None:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
self.paddle_env = {'place': place, 'exe': exe}
exe = exe.run(fluid.default_startup_program())
else:
place = self.paddle_env['place']
exe = self.paddle_env['exe']
data_dict = np.load(data_path).item()
for op_name in data_dict:
layer = self.layers[op_name]
for param_name, data in data_dict[op_name].iteritems():
try:
name = '%s_%s' % (op_name, param_name)
v = fluid.global_scope().find_var(name)
w = v.get_tensor()
w.set(data, place)
except ValueError:
if not ignore_missing:
raise
return True
def feed(self, *args):
'''Set the input(s) for the next operation by replacing the terminal nodes.
The arguments can be either layer names or the actual layers.
'''
assert len(args) != 0
self.terminals = []
for fed_layer in args:
if isinstance(fed_layer, basestring):
try:
fed_layer = self.layers[fed_layer]
except KeyError:
raise KeyError('Unknown layer name fed: %s' % fed_layer)
self.terminals.append(fed_layer)
return self
def get_output(self):
'''Returns the current network output.'''
return self.terminals[-1]
def get_unique_name(self, prefix):
'''Returns an index-suffixed unique name for the given prefix.
This is used for auto-generating layer names based on the type-prefix.
'''
ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1
return '%s_%d' % (prefix, ident)
@layer
def conv(self,
input,
k_h,
k_w,
c_o,
s_h,
s_w,
name,
relu=True,
padding=None,
group=1,
biased=True):
if padding is None:
padding = [0, 0]
# Get the number of channels in the input
c_i, h_i, w_i = input.shape[1:]
# Verify that the grouping parameter is valid
assert c_i % group == 0
assert c_o % group == 0
fluid = import_fluid()
prefix = name + '_'
output = fluid.layers.conv2d(
input=input,
filter_size=[k_h, k_w],
num_filters=c_o,
stride=[s_h, s_w],
padding=padding,
groups=group,
param_attr=fluid.ParamAttr(name=prefix + "weights"),
bias_attr=fluid.ParamAttr(name=prefix + "biases"),
act="relu" if relu is True else None)
return output
@layer
def relu(self, input, name):
fluid = import_fluid()
output = fluid.layers.relu(x=input)
return output
@layer
def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None):
if padding is None:
padding = [0, 0]
# Get the number of channels in the input
h_i, w_i = input.shape[2:]
fluid = import_fluid()
output = fluid.layers.pool2d(
input=input,
pool_size=[k_h, k_w],
pool_stride=[s_h, s_w],
pool_padding=padding,
pool_type='max')
return output
@layer
def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None):
if padding is None:
padding = [0, 0]
# Get the number of channels in the input
h_i, w_i = input.shape[2:]
fluid = import_fluid()
output = fluid.layers.pool2d(
input=input,
pool_size=[k_h, k_w],
pool_stride=[s_h, s_w],
pool_padding=padding,
pool_type='avg')
return output
@layer
def lrn(self, input, radius, alpha, beta, name, bias=1.0):
raise Exception('lrn() not implemented yet')
@layer
def concat(self, inputs, axis, name):
fluid = import_fluid()
output = fluid.layers.concat(input=inputs, axis=axis)
return output
@layer
def add(self, inputs, name):
fluid = import_fluid()
output = inputs[0]
for i in inputs[1:]:
output = fluid.layers.elementwise_add(x=output, y=i)
return output
@layer
def fc(self, input, num_out, name, relu=True, act=None):
fluid = import_fluid()
if act is None:
act = 'relu' if relu is True else None
prefix = name + '_'
output = fluid.layers.fc(
name=name,
input=input,
size=num_out,
act=act,
param_attr=fluid.ParamAttr(name=prefix + 'weights'),
bias_attr=fluid.ParamAttr(name=prefix + 'biases'))
return output
@layer
def softmax(self, input, name):
fluid = import_fluid()
output = fluid.layers.softmax(x=input, name=name)
return output
@layer
def batch_normalization(self, input, name, scale_offset=True, relu=False):
# NOTE: Currently, only inference is supported
fluid = import_fluid()
prefix = name + '_'
param_attr = None if scale_offset is False else fluid.ParamAttr(
name=prefix + 'scale')
bias_attr = None if scale_offset is False else fluid.ParamAttr(
name=prefix + 'offset')
mean_name = prefix + 'mean'
variance_name = prefix + 'variance'
output = fluid.layers.batch_norm(
name=name,
input=input,
is_test=True,
param_attr=param_attr,
bias_attr=bias_attr,
moving_mean_name=mean_name,
moving_variance_name=variance_name,
epsilon=1e-5,
act='relu' if relu is True else None)
return output
@layer
def dropout(self, input, keep_prob, name):
raise Exception('dropout() not implemented yet')
class LeNet(Network):
def setup(self):
self.feed('data')
self.conv(5, 5, 20, 1, 1, relu=False, name='conv1')
self.max_pool(2, 2, 2, 2, name='pool1')
self.conv(5, 5, 50, 1, 1, relu=False, name='conv2')
self.max_pool(2, 2, 2, 2, name='pool2')
self.fc(500, name='ip1')
self.fc(10, relu=False, name='ip2')
self.softmax(name='prob')
@classmethod
def convert(cls, npy_model, fluid_path):
import paddle.v2.fluid as fluid
data_layer = fluid.layers.data(
name="data", shape=[1, 28, 28], dtype="float32")
feed_data = {"data": data_layer}
net = cls(feed_data)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
net.load(data_path=npy_model, exe=exe, place=place)
fluid.io.save_persistables(executor=exe, dirname=fluid_path)
if __name__ == "__main__":
#usage: python xxxnet.py xxx.npy ./model
import sys
npy_weight = sys.argv[1]
fluid_model = sys.argv[2]
LeNet.convert(npy_weight, fluid_model)
exit(0)
#!/bin/env python
#function:
# demo to show how to use converted model using caffe2fluid
#
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from lenet import LeNet as MyNet
def test_model(exe, test_program, fetch_list, test_reader, feeder):
acc_set = []
for data in test_reader():
acc_np, pred = exe.run(program=test_program,
feed=feeder.feed(data),
fetch_list=fetch_list)
acc_set.append(float(acc_np))
acc_val = np.array(acc_set).mean()
return float(acc_val)
def main(model_path):
""" main
"""
print('load fluid model in %s' % (model_path))
with_gpu = False
paddle.init(use_gpu=with_gpu)
#1, define network topology
images = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
net = MyNet({'data': images})
prediction = net.layers['prob']
acc = fluid.layers.accuracy(input=prediction, label=label)
place = fluid.CUDAPlace(0) if with_gpu is True else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
#2, load weights
if model_path.find('.npy') > 0:
net.load(data_path=model_path, exe=exe, place=place)
else:
net.load(data_path=model_path, exe=exe)
#3, test this model
test_program = fluid.default_main_program().clone()
test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
fetch_list = [acc, prediction]
print('go to test model using test set')
acc_val = test_model(exe, test_program, \
fetch_list, test_reader, feeder)
print('test accuracy is [%.4f], expected value[0.919]' % (acc_val))
if __name__ == "__main__":
import sys
if len(sys.argv) == 2:
fluid_model_path = sys.argv[1]
else:
fluid_model_path = './model.fluid'
main(fluid_model_path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册