未验证 提交 1907be42 编写于 作者: J Jason 提交者: GitHub

Merge pull request #1 from Renwb1991/caffe2fluid

X2Paddle: add caffe2fluid
proto/caffepb.py
proto/caffe_pb2.py
### Caffe2Fluid
This tool is used to convert a Caffe model to a Fluid model
### Key Features
1. Convert caffe model to fluid model with codes of defining a network(useful for re-training)
2. Pycaffe is not necessary when just want convert model without do caffe-inference
3. Caffe's customized layers convertion also be supported by extending this tool
4. A bunch of tools in `examples/imagenet/tools` are provided to compare the difference
### HowTo
1. Prepare `caffepb.py` in `./proto` if your python has no `pycaffe` module, two options provided here:
- Generate pycaffe from caffe.proto
```
bash ./proto/compile.sh
```
- Download one from github directly
```
cd proto/ && wget https://raw.githubusercontent.com/ethereon/caffe-tensorflow/master/kaffe/caffe/caffepb.py
```
2. Convert the Caffe model to Fluid model
- Generate fluid code and weight file
```
python convert.py alexnet.prototxt \
--caffemodel alexnet.caffemodel \
--data-output-path alexnet.npy \
--code-output-path alexnet.py
```
- Save weights as fluid model file
```
# only infer the last layer's result
python alexnet.py alexnet.npy ./fluid
# infer these 2 layer's result
python alexnet.py alexnet.npy ./fluid fc8,prob
```
3. Use the converted model to infer
- See more details in `examples/imagenet/tools/run.sh`
4. Compare the inference results with caffe
- See more details in `examples/imagenet/tools/diff.sh`
### How to convert custom layer
1. Implement your custom layer in a file under `kaffe/custom_layers`, eg: mylayer.py
- Implement ```shape_func(input_shape, [other_caffe_params])``` to calculate the output shape
- Implement ```layer_func(inputs, name, [other_caffe_params])``` to construct a fluid layer
- Register these two functions ```register(kind='MyType', shape=shape_func, layer=layer_func)```
- Notes: more examples can be found in `kaffe/custom_layers`
2. Add ```import mylayer``` to `kaffe/custom_layers/\_\_init__.py`
3. Prepare your pycaffe as your customized version(same as previous env prepare)
- (option1) replace `proto/caffe.proto` with your own caffe.proto and compile it
- (option2) change your `pycaffe` to the customized version
4. Convert the Caffe model to Fluid model
5. Set env $CAFFE2FLUID_CUSTOM_LAYERS to the parent directory of 'custom_layers'
```
export CAFFE2FLUID_CUSTOM_LAYERS=/path/to/caffe2fluid/kaffe
```
6. Use the converted model when loading model in `xxxnet.py` and `xxxnet.npy`(no need if model is already in `fluid/model` and `fluid/params`)
### Tested models
- Lenet:
[model addr](https://github.com/ethereon/caffe-tensorflow/blob/master/examples/mnist)
- ResNets:(ResNet-50, ResNet-101, ResNet-152)
[model addr](https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)
- GoogleNet:
[model addr](https://gist.github.com/jimmie33/7ea9f8ac0da259866b854460f4526034)
- VGG:
[model addr](https://gist.github.com/ksimonyan/211839e770f7b538e2d8)
- AlexNet:
[model addr](https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet)
### Notes
Some of this code come from here: [caffe-tensorflow](https://github.com/ethereon/caffe-tensorflow)
#!/usr/bin/env python
import os
import sys
import numpy as np
import argparse
from kaffe import KaffeError, print_stderr
from kaffe.paddle import Transformer
def fatal_error(msg):
""" fatal error encounted
"""
print_stderr(msg)
exit(-1)
def validate_arguments(args):
""" validate args
"""
if (args.data_output_path is not None) and (args.caffemodel is None):
fatal_error('No input data path provided.')
if (args.caffemodel is not None) and (args.data_output_path is None):
fatal_error('No output data path provided.')
if (args.code_output_path is None) and (args.data_output_path is None):
fatal_error('No output path specified.')
def convert(def_path, caffemodel_path, data_output_path, code_output_path,
phase):
""" convert caffe model to tf/paddle models
"""
try:
transformer = Transformer(def_path, caffemodel_path, phase=phase)
print_stderr('Converting data...')
if caffemodel_path is not None:
data = transformer.transform_data()
print_stderr('Saving data...')
with open(data_output_path, 'wb') as data_out:
np.save(data_out, data)
if code_output_path:
print_stderr('Saving source...')
with open(code_output_path, 'wb') as src_out:
src_out.write(transformer.transform_source())
print_stderr('set env variable before using converted model '\
'if used custom_layers:')
custom_pk_path = os.path.dirname(os.path.abspath(__file__))
custom_pk_path = os.path.join(custom_pk_path, 'kaffe')
print_stderr('export CAFFE2FLUID_CUSTOM_LAYERS=%s' % (custom_pk_path))
print_stderr('Done.')
return 0
except KaffeError as err:
fatal_error('Error encountered: {}'.format(err))
return 1
def main():
""" main
"""
parser = argparse.ArgumentParser()
parser.add_argument('def_path', help='Model definition (.prototxt) path')
parser.add_argument('--caffemodel', help='Model data (.caffemodel) path')
parser.add_argument('--data-output-path', help='Converted data output path')
parser.add_argument(
'--code-output-path', help='Save generated source to this path')
parser.add_argument(
'-p',
'--phase',
default='test',
help='The phase to convert: test (default) or train')
args = parser.parse_args()
validate_arguments(args)
return convert(args.def_path, args.caffemodel, args.data_output_path,
args.code_output_path, args.phase)
if __name__ == '__main__':
ret = main()
sys.exit(ret)
A demo to show converting caffe models trained on 'imagenet' using caffe2fluid
---
# How to use
1. Prepare python environment
2. Download caffe model to "models.caffe/xxx" which contains "xxx.caffemodel" and "xxx.prototxt"
3. Convert the Caffe model to Fluid model
- generate fluid code and weight file
```python convert.py alexnet.prototxt \
--caffemodel alexnet.caffemodel \
--data-output-path alexnet.npy \
--code-output-path alexnet.py
```
- save weights as fluid model file
```
python alexnet.py alexnet.npy ./fluid
```
4. Do inference
```
python infer.py infer ./fluid data/65.jpeg
```
5. convert model and do inference together
```
bash ./tools/run.sh alexnet ./models.caffe/alexnet ./models/alexnet
```
* Assume the Caffe model is stored in '*./models.caffe/alexnet/alexnet.prototxt|caffemodel*'
* converted model will be stored as '*./models/alexnet/alexnet.py|npy*'
6. test the difference with caffe's results(need pycaffe installed)
```
bash ./tools/diff.sh resnet
```
* Make sure your caffemodel stored in '*./models.caffe/resnet*'
* The results will be stored in '*./results/resnet.paddle|caffe*'
#!/usr/bin/python
#
#a tool to compare tensors in two files or two directories
#
import sys
import os
def walk_dir(rootdir):
for subdir, dirs, files in os.walk(rootdir):
for file in files:
yield file
def calc_diff(f1, f2):
import numpy as np
d1 = np.load(f1)
d2 = np.load(f2)
#print d1.shape
#print d2.shape
#print d1[0, 0, 0:10, 0:10]
#print d2[0, 0, 0:10, 0:10]
d1 = d1.flatten()
d2 = d2.flatten()
d1_num = reduce(lambda x, y: x * y, d1.shape)
d2_num = reduce(lambda x, y: x * y, d2.shape)
if d1_num != d2_num:
print d1.shape
print d2.shape
assert (d1_num == d2_num), "their shape is not consistent"
try:
mask = np.abs(d1) >= np.abs(d2)
mask = mask.astype('int32')
df = np.abs(d1 - d2)
df = df / (1.0e-10 + np.abs(d1) * mask + np.abs(d2) * (1 - mask))
max_df = np.max(df)
sq_df = np.mean(df * df)
return max_df, sq_df
except Exception as e:
return 1.0, 1.0
def compare(path1, path2, no_exception):
def diff(f1, f2):
max_df, sq_df = calc_diff(f1, f2)
print('[max_df:%.4e, sq_df:%.4e] when compare %s <=> %s' %
(max_df, sq_df, os.path.basename(f1), os.path.basename(f2)))
if no_exception is False:
assert (max_df < 1e-5), \
'max_df is too large with value[%.6e]' % (max_df)
assert (sq_df < 1e-10), \
'sq_df is too large with value[%.6e]' % (sq_df)
if os.path.exists(path1) is False:
print('not found %s' % (path1))
return 1
elif os.path.exists(path2) is False:
print('not found %s' % (path2))
return 1
if path1.find('.npy') > 0 and path2.find('.npy') > 0:
diff(path1, path2)
return
for f in walk_dir(path2):
if f.find('.npy') < 0:
continue
f1 = os.path.join(path1, f)
f2 = os.path.join(path2, f)
diff(f1, f2)
print('all checking succeed to pass')
return 0
if __name__ == "__main__":
if len(sys.argv) == 1:
path1 = 'lenet.tf/results'
path2 = 'lenet.paddle/results'
elif len(sys.argv) >= 3:
path1 = sys.argv[1]
path2 = sys.argv[2]
if len(sys.argv) == 4:
no_exception = True
else:
no_exception = False
else:
print('usage:')
print(' %s [path1] [path2]' % (sys.argv[0]))
exit(1)
#print('compare inner result in %s %s' % (path1, path2))
exit(compare(path1, path2, no_exception))
#!/bin/env python
#function:
# a demo to show how to use the converted model genereated by caffe2fluid
#
#notes:
# only support imagenet data
import os
import sys
import inspect
import numpy as np
def import_fluid():
import paddle.fluid as fluid
return fluid
def load_data(imgfile, shape):
h, w = shape[1:]
from PIL import Image
im = Image.open(imgfile)
# The storage order of the loaded image is W(widht),
# H(height), C(channel). PaddlePaddle requires
# the CHW order, so transpose them.
im = im.resize((w, h), Image.ANTIALIAS)
im = np.array(im).astype(np.float32)
im = im.transpose((2, 0, 1)) # CHW
im = im[(2, 1, 0), :, :] # BGR
# The mean to be subtracted from each image.
# By default, the per-channel ImageNet mean.
mean = np.array([104., 117., 124.], dtype=np.float32)
mean = mean.reshape([3, 1, 1])
im = im - mean
return im.reshape([1] + shape)
def build_model(net_file, net_name):
print('build model with net_file[%s] and net_name[%s]' %
(net_file, net_name))
net_path = os.path.dirname(net_file)
module_name = os.path.splitext(os.path.basename(net_file))[0]
if net_path not in sys.path:
sys.path.insert(0, net_path)
try:
m = __import__(module_name, fromlist=[net_name])
MyNet = getattr(m, net_name)
except Exception as e:
print('failed to load module[%s.%s]' % (module_name, net_name))
print(e)
return None
fluid = import_fluid()
inputs_dict = MyNet.input_shapes()
input_name = inputs_dict.keys()[0]
input_shape = inputs_dict[input_name]
images = fluid.layers.data(
name=input_name, shape=input_shape, dtype='float32')
#label = fluid.layers.data(name='label', shape=[1], dtype='int64')
net = MyNet({input_name: images})
return net, inputs_dict
def dump_results(results, names, root):
if os.path.exists(root) is False:
os.mkdir(root)
for i in range(len(names)):
n = names[i]
res = results[i]
filename = os.path.join(root, n)
np.save(filename + '.npy', res)
def normalize_name(name_map):
return {
k.replace('/', '_'): v.replace('/', '_')
for k, v in name_map.items()
}
def rename_layer_name(names, net):
""" because the names of output layers from caffe maybe changed for 'INPLACE' operation,
and paddle's layers maybe fused, so we need to re-mapping their relationship for comparing
"""
#build a mapping from paddle's name to caffe's name
trace = getattr(net, 'name_trace', None)
cf_trace = trace['caffe']
real2cf = normalize_name(cf_trace['real2chg'])
pd_trace = trace['paddle']
pd2real = normalize_name(pd_trace['chg2real'])
pd_deleted = normalize_name(pd_trace['deleted'])
pd2cf_name = {}
for pd_name, real_name in pd2real.items():
if real_name in real2cf:
pd2cf_name[pd_name] = '%s.%s.%s.both_changed' \
% (real2cf[real_name], real_name, pd_name)
else:
pd2cf_name[pd_name] = '%s.%s.pd_changed' % (real_name, pd_name)
for pd_name, trace in pd_deleted.items():
assert pd_name not in pd2cf_name, "this name[%s] has already exist" % (
pd_name)
pd2cf_name[pd_name] = '%s.pd_deleted' % (pd_name)
for real_name, cf_name in real2cf.items():
if cf_name not in pd2cf_name:
pd2cf_name[cf_name] = '%s.cf_deleted' % (cf_name)
if real_name not in pd2cf_name:
pd2cf_name[real_name] = '%s.%s.cf_changed' % (cf_name, real_name)
ret = []
for name in names:
new_name = pd2cf_name[name] if name in pd2cf_name else name
print('remap paddle name[%s] to output name[%s]' % (name, new_name))
ret.append(new_name)
return ret
def load_model(exe, place, net_file, net_name, net_weight, debug):
""" load model using xxxnet.py and xxxnet.npy
"""
fluid = import_fluid()
#1, build model
net, input_map = build_model(net_file, net_name)
feed_names = input_map.keys()
feed_shapes = [v for k, v in input_map.items()]
prediction = net.get_output()
#2, load weights for this model
startup_program = fluid.default_startup_program()
exe.run(startup_program)
#place = fluid.CPUPlace()
if net_weight.find('.npy') > 0:
net.load(data_path=net_weight, exe=exe, place=place)
else:
raise ValueError('not found weight file')
#3, test this model
test_program = fluid.default_main_program().clone()
fetch_list_var = []
fetch_list_name = []
if debug is False:
fetch_list_var.append(prediction)
else:
for k, v in net.layers.items():
fetch_list_var.append(v)
fetch_list_name.append(k)
return {
'program': test_program,
'feed_names': feed_names,
'fetch_vars': fetch_list_var,
'fetch_names': fetch_list_name,
'feed_shapes': feed_shapes,
'net': net
}
def get_shape(fluid, program, name):
for var in program.list_vars():
if var.name == 'data':
return list(var.shape[1:])
raise ValueError('not found shape for input layer[%s], '
'you can specify by yourself' % (name))
def load_inference_model(dirname, exe):
""" load fluid's inference model
"""
fluid = import_fluid()
model_fn = 'model'
params_fn = 'params'
if os.path.exists(os.path.join(dirname, model_fn)) \
and os.path.exists(os.path.join(dirname, params_fn)):
program, feed_names, fetch_targets = fluid.io.load_inference_model(\
dirname, exe, model_fn, params_fn)
else:
raise ValueError('not found model files in direcotry[%s]' % (dirname))
#print fluid.global_scope().find_var(feed_names[0])
input_shape = get_shape(fluid, program, feed_names[0])
feed_shapes = [input_shape]
return program, feed_names, fetch_targets, feed_shapes
def infer(model_path, imgfile, net_file=None, net_name=None, debug=True):
""" do inference using a model which consist 'xxx.py' and 'xxx.npy'
"""
fluid = import_fluid()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
try:
ret = load_inference_model(model_path, exe)
program, feed_names, fetch_targets, feed_shapes = ret
debug = False
print('found a inference model for fluid')
except ValueError as e:
print('try to load model using net file and weight file')
net_weight = model_path
ret = load_model(exe, place, net_file, net_name, net_weight, debug)
program = ret['program']
feed_names = ret['feed_names']
fetch_targets = ret['fetch_vars']
fetch_list_name = ret['fetch_names']
feed_shapes = ret['feed_shapes']
net = ret['net']
input_name = feed_names[0]
input_shape = feed_shapes[0]
np_images = load_data(imgfile, input_shape)
results = exe.run(program=program,
feed={input_name: np_images},
fetch_list=fetch_targets)
if debug is True:
dump_path = 'results.paddle'
dump_names = rename_layer_name(fetch_list_name, net)
dump_results(results, dump_names, dump_path)
print('all result of layers dumped to [%s]' % (dump_path))
else:
result = results[0]
print('succeed infer with results[class:%d]' % (np.argmax(result)))
return 0
def caffe_infer(prototxt, caffemodel, datafile):
""" do inference using pycaffe for debug,
all intermediate results will be dumpped to 'results.caffe'
"""
import caffe
net = caffe.Net(prototxt, caffemodel, caffe.TEST)
input_layer = net.blobs.keys()[0]
print('got name of input layer is:%s' % (input_layer))
input_shape = list(net.blobs[input_layer].data.shape[1:])
if '.npy' in datafile:
np_images = np.load(datafile)
else:
np_images = load_data(datafile, input_shape)
inputs = {input_layer: np_images}
net.forward_all(**inputs)
results = []
names = []
for k, v in net.blobs.items():
k = k.replace('/', '_')
names.append(k)
results.append(v.data.copy())
dump_path = 'results.caffe'
dump_results(results, names, dump_path)
print('all result of layers dumped to [%s]' % (dump_path))
return 0
if __name__ == "__main__":
""" maybe more convenient to use 'run.sh' to call this tool
"""
net_file = 'models/resnet50/resnet50.py'
weight_file = 'models/resnet50/resnet50.npy'
datafile = 'data/65.jpeg'
net_name = 'ResNet50'
model_file = 'models/resnet50/fluid'
ret = None
if len(sys.argv) <= 2:
pass
elif sys.argv[1] == 'caffe':
if len(sys.argv) != 5:
print('usage:')
print('\tpython %s caffe [prototxt] [caffemodel] [datafile]' %
(sys.argv[0]))
sys.exit(1)
prototxt = sys.argv[2]
caffemodel = sys.argv[3]
datafile = sys.argv[4]
ret = caffe_infer(prototxt, caffemodel, datafile)
elif sys.argv[1] == 'infer':
if len(sys.argv) != 4:
print('usage:')
print('\tpython %s infer [fluid_model] [datafile]' % (sys.argv[0]))
sys.exit(1)
model_path = sys.argv[2]
datafile = sys.argv[3]
ret = infer(model_path, datafile)
elif sys.argv[1] == 'dump':
if len(sys.argv) != 6:
print('usage:')
print('\tpython %s dump [net_file] [weight_file] [datafile] [net_name]' \
% (sys.argv[0]))
print('\teg:python %s dump %s %s %s %s' % (sys.argv[0],\
net_file, weight_file, datafile, net_name))
sys.exit(1)
net_file = sys.argv[2]
weight_file = sys.argv[3]
datafile = sys.argv[4]
net_name = sys.argv[5]
ret = infer(weight_file, datafile, net_file, net_name)
if ret is None:
print('usage:')
print(' python %s [infer] [fluid_model] [imgfile]' % (sys.argv[0]))
print(' eg:python %s infer %s %s' % (sys.argv[0], model_file, datafile))
sys.exit(1)
sys.exit(ret)
#!/bin/bash
#
#function:
# a tool used to compare the results produced by paddle and caffe
#
if [[ $# -lt 2 ]];then
echo "usage:"
echo " bash $0 [model_name] [param_name] [caffe_name]"
exit 1
fi
model_name=$1
param_name=$2
paddle_file="./results/${model_name}.paddle/${param_name}.npy"
if [[ $# -eq 3 ]];then
caffe_file="./results/${model_name}.caffe/${3}.npy"
else
caffe_file="./results/${model_name}.caffe/${2}.npy"
fi
cmd="python ./compare.py $paddle_file $caffe_file"
echo $cmd
eval $cmd
#!/bin/bash
#function:
# a tool used to compare all layers' results
#
#set -x
if [[ $# -ne 1 ]];then
echo "usage:"
echo " bash $0 [model_name]"
echo " eg:bash $0 alexnet"
exit 1
fi
model_name=$1
prototxt="models.caffe/$model_name/${model_name}.prototxt"
cat $prototxt | grep name | perl -ne 'if(/^\s*name\s*:\s+\"([^\"]+)/){ print $1."\n";}' >.layer_names
final_layer=$(cat $prototxt | perl -ne 'if(/^\s*top\s*:\s+\"([^\"]+)/){ print $1."\n";}' | tail -n1)
ret=$(grep "^$final_layer$" .layer_names | wc -l)
if [[ $ret -eq 0 ]];then
echo $final_layer >>.layer_names
fi
for i in $(cat .layer_names);do
i=${i//\//_}
cf_npy="results/${model_name}.caffe/${i}.npy"
#pd_npy="results/${model_name}.paddle/${i}.npy"
#pd_npy=$(find results/${model_name}.paddle -iname "${i}*.npy" | head -n1)
pd_npy=$(find results/${model_name}.paddle -iname "${i}.*npy" | grep deleted -v | head -n1)
if [[ ! -e $cf_npy ]];then
echo "caffe's result not exist[$cf_npy]"
continue
fi
if [[ ! -e $pd_npy ]];then
echo "paddle's result not exist[$pd_npy]"
continue
fi
python compare.py $cf_npy $pd_npy no_exception
if [[ $? -eq 0 ]];then
echo "succeed to compare layer[$i]"
else
echo "failed to compare layer[$i]"
fi
done
#!/bin/bash
#
#function:
# a tool used to check the difference of models' results generated by caffe model and paddle model
#
#howto:
# bash diff.sh resnet50 #when this has been finished, you can get the difference in precision
#
#notes:
# 0, in order to infer using caffe, we need pycaffe installed
# 1, prepare your caffe model in 'models.caffe/', eg: 'model.caffe/resnet101/resnet101.[prototxt|caffemodel]'
# 2, converted paddle model will be in 'models'
# 3, results of layers will be stored in 'results/${model_name}.[paddle|caffe]'
# 4, only the last layer will be checked by default
model_name="resnet50"
results_root="results/"
if [[ -n $1 ]];then
if [ $1 = "-h" ];then
echo "usage:"
echo " bash $0 [model_name]"
echo " eg:bash $0 resnet50"
exit 0
fi
model_name=$1
fi
mkdir -p $results_root
prototxt="models.caffe/$model_name/${model_name}.prototxt"
caffemodel="models.caffe/${model_name}/${model_name}.caffemodel"
#1, dump layers' results from paddle
paddle_results="$results_root/${model_name}.paddle"
rm -rf $paddle_results
rm -rf "results.paddle"
bash ./tools/run.sh $model_name ./models.caffe/$model_name ./models/$model_name
if [[ $? -ne 0 ]] || [[ ! -e "results.paddle" ]];then
echo "not found paddle's results, maybe failed to convert"
exit 1
fi
mv results.paddle $paddle_results
#2, dump layers' results from caffe
caffe_results="$results_root/${model_name}.caffe"
rm -rf $caffe_results
rm -rf "results.caffe"
PYTHON=`which cfpython`
if [[ -z $PYTHON ]];then
PYTHON=`which python`
fi
$PYTHON ./infer.py caffe $prototxt $caffemodel $paddle_results/data.npy
if [[ $? -ne 0 ]] || [[ ! -e "results.caffe" ]];then
echo "not found caffe's results, maybe failed to do inference with caffe"
exit 1
fi
mv results.caffe $caffe_results
#3, extract layer names
cat $prototxt | grep name | perl -ne 'if(/^\s*name\s*:\s+\"([^\"]+)/){ print $1."\n";}' >.layer_names
final_layer=$(cat $prototxt | perl -ne 'if(/^\s*top\s*:\s+\"([^\"]+)/){ print $1."\n";}' | tail -n1)
ret=$(grep "^$final_layer$" .layer_names | wc -l)
if [[ $ret -eq 0 ]];then
echo $final_layer >>.layer_names
fi
#4, compare one by one
#for i in $(cat .layer_names);do
for i in $(cat .layer_names | tail -n1);do
i=${i//\//_}
echo "process $i"
pd_npy=$(find $paddle_results/ -iname "${i}.*npy" | grep deleted -v | head -n1)
#pd_npy="$paddle_results/${i}.npy"
if [[ -f $pd_npy ]];then
$PYTHON compare.py $caffe_results/${i}.npy $pd_npy
else
echo "not found npy file[${i}.*npy] for layer[$i]"
exit 1
fi
done
#!/bin/bash
#function:
# a tool used to:
# 1, convert a caffe model
# 2, do inference(only in fluid) using this model
#
#usage:
# cd caffe2fluid/examples/imagenet && bash run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50
#
#set -x
if [[ $# -lt 3 ]];then
echo "usage:"
echo " bash $0 [model_name] [cf_model_path] [pd_model_path] [only_convert]"
echo " eg: bash $0 resnet50 ./models.caffe/resnet50 ./models/resnet50"
exit 1
else
model_name=$1
cf_model_path=$2
pd_model_path=$3
only_convert=$4
fi
proto_file=$cf_model_path/${model_name}.prototxt
caffemodel_file=$cf_model_path/${model_name}.caffemodel
weight_file=$pd_model_path/${model_name}.npy
net_file=$pd_model_path/${model_name}.py
if [[ ! -e $proto_file ]];then
echo "not found prototxt[$proto_file]"
exit 1
fi
if [[ ! -e $caffemodel_file ]];then
echo "not found caffemodel[$caffemodel_file]"
exit 1
fi
if [[ ! -e $pd_model_path ]];then
mkdir $pd_model_path
fi
PYTHON=`which cfpython`
if [[ -z $PYTHON ]];then
PYTHON=`which python`
fi
$PYTHON ../../convert.py \
$proto_file \
--caffemodel $caffemodel_file \
--data-output-path $weight_file\
--code-output-path $net_file
ret=$?
if [[ $ret -ne 0 ]];then
echo "failed to convert caffe model[$cf_model_path]"
exit $ret
else
echo "succeed to convert caffe model[$cf_model_path] to fluid model[$pd_model_path]"
fi
if [[ -z $only_convert ]];then
PYTHON=`which pdpython`
if [[ -z $PYTHON ]];then
PYTHON=`which python`
fi
imgfile="data/65.jpeg"
#FIX ME:
# only look the first line in prototxt file for the name of this network, maybe not correct
net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/^name\s*:\s*\"([^\"]+)\"/){ print $1."\n";}'`
if [[ -z $net_name ]];then
net_name="MyNet"
fi
cmd="$PYTHON ./infer.py dump $net_file $weight_file $imgfile $net_name"
echo $cmd
eval $cmd
ret=$?
fi
exit $ret
#!/bin/bash
#
#script to test all models
#
models="alexnet vgg16 googlenet resnet152 resnet101 resnet50"
for i in $models;do
echo "begin to process $i"
bash ./tools/diff.sh $i 2>&1
echo "finished to process $i with ret[$?]"
done
a demo to show converting caffe model on 'mnist' using caffe2fluid
---
# How to use
1. prepare python environment
2. download caffe model to "models.caffe/lenet" which contains "lenet.caffemodel" and "lenet.prototxt"
3. run the tool
eg: bash ./run.sh lenet ./models.caffe/lenet ./models/lenet
#!/bin/env python
#function:
# demo to show how to use converted model using caffe2fluid
#
import sys
import os
import numpy as np
import paddle.fluid as fluid
import paddle
def test_model(exe, test_program, fetch_list, test_reader, feeder):
acc_set = []
for data in test_reader():
acc_np, pred = exe.run(program=test_program,
feed=feeder.feed(data),
fetch_list=fetch_list)
acc_set.append(float(acc_np))
acc_val = np.array(acc_set).mean()
return float(acc_val)
def evaluate(net_file, model_file):
""" main
"""
#1, build model
net_path = os.path.dirname(net_file)
if net_path not in sys.path:
sys.path.insert(0, net_path)
from lenet import LeNet as MyNet
#1, define network topology
images = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
net = MyNet({'data': images})
prediction = net.layers['prob']
acc = fluid.layers.accuracy(input=prediction, label=label)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
#2, load weights
if model_file.find('.npy') > 0:
net.load(data_path=model_file, exe=exe, place=place)
else:
net.load(data_path=model_file, exe=exe)
#3, test this model
test_program = fluid.default_main_program().clone()
test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
fetch_list = [acc, prediction]
print('go to test model using test set')
acc_val = test_model(exe, test_program, \
fetch_list, test_reader, feeder)
print('test accuracy is [%.4f], expected value[0.919]' % (acc_val))
if __name__ == "__main__":
net_file = 'models/lenet/lenet.py'
weight_file = 'models/lenet/lenet.npy'
argc = len(sys.argv)
if argc == 3:
net_file = sys.argv[1]
weight_file = sys.argv[2]
elif argc > 1:
print('usage:')
print('\tpython %s [net_file] [weight_file]' % (sys.argv[0]))
print('\teg:python %s %s %s %s' % (sys.argv[0], net_file, weight_file))
sys.exit(1)
evaluate(net_file, weight_file)
#!/bin/bash
#function:
# a tool used to:
# 1, convert a caffe model
# 2, do inference using this model
#
#usage:
# bash run.sh lenet ./models.caffe/lenet ./models/lenet
#
#set -x
if [[ $# -lt 3 ]];then
echo "usage:"
echo " bash $0 [model_name] [cf_model_path] [pd_model_path] [only_convert]"
echo " eg: bash $0 lenet ./models.caffe/lenet ./models/lenet"
exit 1
else
model_name=$1
cf_model_path=$2
pd_model_path=$3
no_eval=$4
fi
proto_file=$cf_model_path/${model_name}.prototxt
caffemodel_file=$cf_model_path/${model_name}.caffemodel
weight_file=$pd_model_path/${model_name}.npy
net_file=$pd_model_path/${model_name}.py
if [[ ! -e $proto_file ]];then
echo "not found prototxt[$proto_file]"
exit 1
fi
if [[ ! -e $caffemodel_file ]];then
echo "not found caffemodel[$caffemodel_file]"
exit 1
fi
if [[ ! -e $pd_model_path ]];then
mkdir $pd_model_path
fi
PYTHON=`which cfpython`
if [[ -z $PYTHON ]];then
PYTHON=`which python`
fi
$PYTHON ../../convert.py \
$proto_file \
--caffemodel $caffemodel_file \
--data-output-path $weight_file\
--code-output-path $net_file
ret=$?
if [[ $ret -ne 0 ]];then
echo "failed to convert caffe model[$cf_model_path]"
exit $ret
else
echo "succeed to convert caffe model[$cf_model_path] to fluid model[$pd_model_path]"
fi
if [[ -z $only_convert ]];then
PYTHON=`which pdpython`
if [[ -z $PYTHON ]];then
PYTHON=`which python`
fi
net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/\"([^\"]+)\"/){ print $1."\n";}'`
if [[ $net_name != "LeNet" ]];then
echo "only support LeNet"
exit 1
fi
$PYTHON ./evaluate.py $net_file $weight_file
ret=$?
fi
exit $ret
from .graph import GraphBuilder, NodeMapper
from .errors import KaffeError, print_stderr
import os
from . import paddle
from .resolver import get_caffe_resolver, has_pycaffe
import os
import sys
SHARED_CAFFE_RESOLVER = None
def import_caffepb():
p = os.path.realpath(__file__)
p = os.path.dirname(p)
p = os.path.join(p, '../../proto')
sys.path.insert(0, p)
import caffe_pb2
return caffe_pb2
class CaffeResolver(object):
def __init__(self):
self.import_caffe()
def import_caffe(self):
self.caffe = None
try:
# Try to import PyCaffe first
import caffe
self.caffe = caffe
except ImportError:
# Fall back to the protobuf implementation
self.caffepb = import_caffepb()
show_fallback_warning()
if self.caffe:
# Use the protobuf code from the imported distribution.
# This way, Caffe variants with custom layers will work.
self.caffepb = self.caffe.proto.caffe_pb2
self.NetParameter = self.caffepb.NetParameter
def has_pycaffe(self):
return self.caffe is not None
def get_caffe_resolver():
global SHARED_CAFFE_RESOLVER
if SHARED_CAFFE_RESOLVER is None:
SHARED_CAFFE_RESOLVER = CaffeResolver()
return SHARED_CAFFE_RESOLVER
def has_pycaffe():
return get_caffe_resolver().has_pycaffe()
def show_fallback_warning():
msg = '''
------------------------------------------------------------
WARNING: PyCaffe not found!
Falling back to a pure protocol buffer implementation.
* Conversions will be drastically slower.
------------------------------------------------------------
'''
sys.stderr.write(msg)
"""
"""
from .register import get_registered_layers
#custom layer import begins
import axpy
import flatten
import argmax
import reshape
import roipooling
import priorbox
import permute
import detection_out
import normalize
import select
import crop
import power
import reduction
#custom layer import ends
custom_layers = get_registered_layers()
def set_args(f, params, node=None):
""" set args for function 'f' using the parameters in node.layer.parameters
Args:
f (function): a python function object
params (object): a object contains attributes needed by f's arguments
Returns:
arg_names (list): a list of argument names
kwargs (dict): a dict contains needed arguments
"""
from ..protobuf_to_dict import protobuf_to_dict
argc = f.__code__.co_argcount
arg_list = f.__code__.co_varnames[0:argc]
kwargs = {}
for arg_name in arg_list:
if arg_name in params:
kwargs[arg_name] = params[arg_name]
if node is not None and len(node.metadata):
kwargs.update(node.metadata)
return arg_list, kwargs
def has_layer(kind):
""" test whether this layer exists in custom layer
"""
return kind in custom_layers
def compute_output_shape(kind, node):
assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
kind)
shape_func = custom_layers[kind]['shape']
parents = node.parents
inputs = [list(p.output_shape) for p in parents]
arg_names, kwargs = set_args(shape_func, node.params)
if len(inputs) == 1:
inputs = inputs[0]
return shape_func(inputs, **kwargs)
def make_node(template, kind, node):
""" make a PaddleNode for custom layer which means construct
a piece of code to define a layer implemented in 'custom_layers'
Args:
@template (PaddleNode): a factory to new a instance of PaddleNode
@kind (str): type of custom layer
@node (graph.Node): a layer in the net
Returns:
instance of PaddleNode
"""
assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
kind)
layer_func = custom_layers[kind]['layer']
#construct arguments needed by custom layer function from node's parameters
arg_names, kwargs = set_args(layer_func, node.params, node)
return template('custom_layer', kind, **kwargs)
def make_custom_layer(kind, inputs, name, *args, **kwargs):
""" execute a custom layer which is implemented by users
Args:
@kind (str): type name of this layer
@inputs (vars): variable list created by fluid
@namme (str): name for this layer
@args (tuple): other positional arguments
@kwargs (dict): other kv arguments
Returns:
output (var): output variable for this layer
"""
assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
kind)
layer_func = custom_layers[kind]['layer']
return layer_func(inputs, name, *args, **kwargs)
""" a custom layer for 'argmax', maybe we should implement this in standard way.
more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/argmax.html
"""
from .register import register
def import_fluid():
import paddle.fluid as fluid
return fluid
def argmax_shape(input_shape, out_max_val=False, top_k=1, axis=-1):
""" calculate the output shape of this layer using input shape
Args:
@input_shape (list of num): a list of number which represents the input shape
@out_max_val (bool): parameter from caffe's ArgMax layer
@top_k (int): parameter from caffe's ArgMax layer
@axis (int): parameter from caffe's ArgMax layer
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
input_shape = list(input_shape)
if axis < 0:
axis += len(input_shape)
assert (axis + 1 == len(input_shape)
), 'only can be applied on the last dimension[axis:%d, %s] now,'\
'make sure you have set axis param in xxx.prototxt file' \
% (axis, str(input_shape))
output_shape = input_shape
output_shape[-1] = top_k
if out_max_val is True:
output_shape[-1] *= 2
return output_shape
def argmax_layer(input, name, out_max_val=False, top_k=1, axis=-1):
""" build a layer of type 'ArgMax' using fluid
Args:
@input (variable): input fluid variable for this layer
@name (str): name for this layer
@out_max_val (bool): parameter from caffe's ArgMax layer
@top_k (int): parameter from caffe's ArgMax layer
@axis (int): parameter from caffe's ArgMax layer
Returns:
output (variable): output variable for this layer
"""
fluid = import_fluid()
if axis < 0:
axis += len(input.shape)
if out_max_val is True:
topk_var, index_var = fluid.layers.topk(input=input, k=top_k)
index_var = fluid.layers.cast(index_var, dtype=topk_var.dtype)
output = fluid.layers.concat(
[index_var, topk_var], axis=axis, name=name)
else:
topk_var, index_var = fluid.layers.topk(input=input, k=top_k, name=name)
output = index_var
return output
register(kind='ArgMax', shape=argmax_shape, layer=argmax_layer)
""" A custom layer for 'axpy' which receives 3 tensors and output 1 tensor.
the function performed is:(the mupltiplication and add are elementewise)
output = inputs[0] * inputs[1] + inputs[2]
"""
from .register import register
def axpy_shape(input_shapes):
""" calculate the output shape of this layer using input shapes
Args:
@input_shapes (list of tuples): a list of input shapes
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
assert len(input_shapes) == 3, "not valid input shape for axpy layer"
assert len(input_shapes[0]) == len(input_shapes[1]), 'should have same dims'
output_shape = input_shapes[1]
assert (input_shapes[2] == output_shape),\
"shape not consistent for axpy[%s <--> %s]" \
% (str(output_shape), str(input_shapes[2]))
return output_shape
def axpy_layer(inputs, name):
""" build a layer of type 'Axpy' using fluid
Args:
@inputs (list of variables): input fluid variables for this layer
@name (str): name for this layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
assert len(inputs) == 3, "invalid inputs for axpy[%s]" % (name)
alpha = inputs[0]
x = inputs[1]
y = inputs[2]
output = fluid.layers.elementwise_mul(x, alpha, axis=0)
output = fluid.layers.elementwise_add(output, y, name=name)
return output
register(kind='Axpy', shape=axpy_shape, layer=axpy_layer)
""" a custom layer for 'crop', maybe we should implement this in standard way.
more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/crop.html
"""
from .register import register
def crop_shape(input_shape, shape=None):
""" calculate the output shape of this layer using input shape
Args:
@input_shape (num | list of num): a list of number or num which represents the input shape
@shape (list of integer): the shape of output
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
if isinstance(input_shape, list):
assert len(input_shape) == 2, "the number of crop's inputs must be 2"
return input_shape[1]
elif not shape is None:
assert len(shape) == len(
input_shape.shape), "input_shape is diff with output_shape"
return shape
else:
raise Exception, "crop_shape input error"
return None
def crop_layer(input, name, shape=None, axis=2, offset=None):
""" build a layer of type 'Crop' using fluid
Args:
@input (variables | list of variables): input fluid variable for this layer
@shape (list of integer): the shape of output
@name (str): name for this layer
@axis (integer): parameter from caffe's Crop layer
@offset (Variable|list/tuple of integer|None): parameter from caffe's Crop layer
Returns:
output (variable): output variable for this layer
"""
input_shape = None
output_shape = None
input_tensor = None
if isinstance(input, list):
assert len(input) == 2, "the number of crop's inputs must be 2"
input_shape = input[0].shape
output_shape = input[1].shape
input_tensor = input[0]
elif not shape is None:
assert len(shape) == len(
input.shape), "input_shape is diff with output_shape"
input_shape = input.shape
output_shape = shape
input_tensor = input
else:
raise Exception, "crop_layer input error"
assert len(output_shape) == len(
input_shape), "input_shape is diff with output_shape"
if axis < 0:
axis += len(input_shape)
if offset is not None:
assert (len(input_shape) - axis
) == len(offset), "invalid offset[%s] in crop layer" % (
str(offset))
offset = [0] * axis + offset
import paddle.fluid as fluid
output = fluid.layers.crop(
input_tensor, shape=output_shape, offsets=offset, name=name)
return output
register(kind='Crop', shape=crop_shape, layer=crop_layer)
""" A custom layer for 'detectionout' used in 'SSD' model to produce outputs
Note: Since Paddle's implementation of 'detectionout' applied 'flatten' and 'softmax' ops on the input of 'conf',
while Caffe's implementation do not.
"""
from .register import register
def detectionoutput_shape(input_shape):
""" the output shape of this layer is dynamic and not determined by 'input_shape'
Args:
@input_shape (list of int): input shape
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
output_shape = [-1, 6]
return output_shape
def detectionoutput_layer(inputs,
name,
background_label=0,
share_location=True,
nms_param=None,
keep_top_k=100,
confidence_threshold=0.1):
""" build a layer of type 'detectionout' using fluid
Args:
@inputs (list of variables): input fluid variables for this layer
@name (str): name for this layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
if nms_param is None:
nms_param = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
mbox_conf_flatten = inputs[1]
mbox_priorbox = inputs[2]
mbox_priorbox_list = fluid.layers.split(mbox_priorbox, 2, dim=1)
pb = mbox_priorbox_list[0]
pbv = mbox_priorbox_list[1]
pb = fluid.layers.reshape(x=pb, shape=[-1, 4])
pbv = fluid.layers.reshape(x=pbv, shape=[-1, 4])
mbox_loc = inputs[0]
mbox_loc = fluid.layers.reshape(
x=mbox_loc, shape=[-1, mbox_conf_flatten.shape[1], 4])
default = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
fields = ['eta', 'top_k', 'nms_threshold']
for f in default.keys():
if not nms_param.has_key(f):
nms_param[f] = default[f]
nmsed_outs = fluid.layers.detection_output(
scores=mbox_conf_flatten,
loc=mbox_loc,
prior_box=pb,
prior_box_var=pbv,
background_label=background_label,
nms_threshold=nms_param["nms_threshold"],
nms_top_k=nms_param["top_k"],
keep_top_k=keep_top_k,
score_threshold=confidence_threshold,
nms_eta=nms_param["eta"])
return nmsed_outs
register(
kind='DetectionOutput',
shape=detectionoutput_shape,
layer=detectionoutput_layer)
""" a custom layer for 'flatten', maybe we should implement this in standard way.
more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/flatten.html
"""
from .register import register
def flatten_shape(input_shape, axis=1, end_axis=-1):
""" calculate the output shape of this layer using input shape
Args:
@input_shape (list of num): a list of number which represents the input shape
@axis (int): parameter from caffe's Flatten layer
@end_axis (int): parameter from caffe's Flatten layer
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
start_axis = axis
end_axis = end_axis
input_shape = list(input_shape)
if start_axis < 0:
start_axis += len(input_shape)
if end_axis < 0:
end_axis += len(input_shape) + 1
assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\
% (start_axis, end_axis)
output_shape = input_shape[0:start_axis]
flat_sz = reduce(lambda a, b: a * b, input_shape[start_axis:end_axis])
output_shape += [flat_sz]
output_shape += input_shape[end_axis:-1]
return output_shape
def flatten_layer(input, name, axis=1, end_axis=-1):
""" build a layer of type 'Flatten' using fluid
Args:
@input (variable): input fluid variable for this layer
@name (str): name for this layer
@axis (int): parameter from caffe's Flatten layer
@end_axis (int): parameter from caffe's Flatten layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
input_shape = list(input.shape)
if input_shape[0] == -1:
input_shape[0] = 1
output_shape = flatten_shape(input_shape, axis=axis, end_axis=end_axis)
output_shape[0] = -1
else:
output_shape = flatten_shape(input_shape, axis=axis, end_axis=end_axis)
output = fluid.layers.reshape(input, shape=output_shape, name=name)
return output
register(kind='Flatten', shape=flatten_shape, layer=flatten_layer)
""" A custom layer for 'normalize' op
"""
from .register import register
def normalize_shape(input_shape,
across_spatial=True,
scale_filler=True,
eps=1e-10):
""" calculate the output shape of this layer using input shapes
Args:
@input_shape (list of tuples): input shape
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
output_shape = input_shape
return output_shape
def normalize_layer(input,
name,
across_spatial=True,
scale_filler=True,
channel_shared=False,
eps=1e-10):
""" build a layer of type 'normalize' using fluid
Args:
@inputs (list of variables): input fluid variables for this layer
@name (str): name for this layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
param_prefix = name.split('.')[0]
assert across_spatial == False, "Only support across_spatial == False for Normalize[%s]" % (
name)
l2_norm = fluid.layers.l2_normalize(input, axis=1) # l2 norm along channel
shape = [1] if channel_shared else [input.shape[1]]
scale_attr = fluid.ParamAttr(name=param_prefix + '_scale')
scale_param = fluid.layers.create_parameter(
shape=shape, dtype=input.dtype, name=name, attr=scale_attr)
out = fluid.layers.elementwise_mul(
x=l2_norm, y=scale_param, axis=-1 if channel_shared else 1)
return out
register(kind='Normalize', shape=normalize_shape, layer=normalize_layer)
""" A custom layer for 'Permute' which is equivalent to transpose in paddle
"""
from .register import register
def permute_shape(input_shape, order):
""" calculate the output shape of this layer using input shapes
Args:
@input_shape (list of numbers): input shape
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
output_shape = []
for ii in order:
assert ii < len(input_shape), "invalid order for permute[%s]" % (name)
output_shape.append(input_shape[ii])
return output_shape
def permute_layer(input, name, order):
""" build a layer of type 'permute' using fluid
Args:
@input (input variable): input fluid variables for this layer
@name (str): name for this layer
@order (list of int): order to permute the dims
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
output = fluid.layers.transpose(input, order, name=name)
return output
register(kind='Permute', shape=permute_shape, layer=permute_layer)
""" a custom layer for 'power', maybe we should implement this in standard way.
more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/power.html
"""
from .register import register
def power_shape(input_shape, shape=None):
""" calculate the output shape of this layer using input shape
Args:
@input_shape (list of num): a list of number which represents the input shape
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
return input_shape
def power_layer(input, name, power=1.0, scale=1.0, shift=0.0):
""" build a layer of type 'Power' using fluid
Args:
@input (variables): input fluid variable for this layer
@name (str): name for this layer
@power (float): parameter from caffe's Power layer
@scale (float): parameter from caffe's Power layer
@shift (float): parameter from caffe's Power layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
scale_out = fluid.layers.scale(
input, scale=scale, bias=shift, bias_after_scale=True)
output = fluid.layers.pow(scale_out, factor=power)
return output
register(kind='Power', shape=power_shape, layer=power_layer)
""" A custom layer for 'priorbox' which is used in ssd to generate prior box info
Since the order of prior box is different between caffe and paddle,
we use 'slice' and 'concate' ops to align them.
"""
from .register import register
def priorbox_shape(input_shapes, min_size, max_size=None, aspect_ratio=None):
""" calculate the output shape of this layer using input shapes
Args:
@input_shapes (list of tuples): a list of input shapes
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
assert len(input_shapes) == 2, "invalid inputs for Priorbox[%s]" % (name)
fc_shape = input_shapes[0]
N = 1
if not max_size == None:
N += 1
if not aspect_ratio == None:
N += 2 * len(aspect_ratio)
N_bbx = fc_shape[2] * fc_shape[3] * N
output_shape = [1, 2, 4 * N_bbx]
return output_shape
def priorbox_layer(inputs,
name,
min_size,
max_size=None,
aspect_ratio=None,
variance=[0.1, 0.1, 0.2, 0.2],
flip=False,
clip=False,
step=0.0,
offset=0.5):
""" build a layer of type 'Priorbox' using fluid
Args:
@inputs (list of variables): input fluid variables for this layer
@name (str): name for this layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
assert len(inputs) == 2, "invalid inputs for Priorbox[%s]" % (name)
input = inputs[0]
image = inputs[1]
steps = tuple(step) if type(step) is list or type(step) is tuple else (step,
step)
box, variance_ = fluid.layers.prior_box(
input,
image,
min_size,
max_size,
aspect_ratio,
variance,
flip,
clip,
steps,
offset,
min_max_aspect_ratios_order=True)
"""
#adjust layout when the output is not consistent with caffe's
feat_shape = list(input.shape)
H = feat_shape[2]
W = feat_shape[3]
box_tmp = fluid.layers.reshape(box, [H, W, -1, 4])
nb_prior_bbx = int(box_tmp.shape[2])
tensor_list = fluid.layers.split(box_tmp, nb_prior_bbx, 2)
#TODO:
# current implementation for this layer is not efficient
# and we should fix this bug in future when Paddle support the same prior-box layout with Caffe
index_list = [0]
index_list = index_list * nb_prior_bbx
index_offset = 0
if max_size is not None:
index_list[1] = -1
index_offset = 1
for ii in xrange(2 * len(aspect_ratio)):
index_list[ii + 1 + index_offset] = ii + 1
tensor_list_gathered = [tensor_list[ii] for ii in index_list]
caffe_prior_bbx = fluid.layers.concat(tensor_list_gathered, axis=2)
box = fluid.layers.reshape(caffe_prior_bbx, [1, 1, -1])
"""
box = fluid.layers.reshape(box, [1, 1, -1])
variance_ = fluid.layers.reshape(variance_, [1, 1, -1])
output = fluid.layers.concat([box, variance_], axis=1)
return output
register(kind='PriorBox', shape=priorbox_shape, layer=priorbox_layer)
""" a custom layer for 'crop', maybe we should implement this in standard way.
more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/reduction.html
"""
from .register import register
def reduction_shape(input_shape, axis=0):
""" calculate the output shape of this layer using input shape
Args:
@input_shape (list of num): a list of number which represents the input shape
@axis (int): parameter from caffe's reduction layer
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
if axis < 0:
axis += len(input_shape) + 1
assert axis <= len(input_shape), 'invalid axis[%d] error' % (axis)
return input_shape[0:axis]
def reduction_layer(input, name, axis=0, operation=1, coeff=1.0):
""" build a layer of type 'Crop' using fluid
Args:
@input (variable): input fluid variable for this layer
@name (str): name for this layer
@axis (int): parameter from caffe's reduction layer
@operation (int): parameter from caffe's reduction layer
@coeff (float): parameter from caffe's reduction layer
Returns:
output (variable): output variable for this layer
"""
assert operation >= 1 and operation <= 4, "reduction reduction [%s] error" % (
operation)
input_len = len(input.shape)
if axis < 0:
axis += input_len + 1
dim = range(input_len)
import paddle.fluid as fluid
if operation == 1: ## operation = SUM
output = fluid.layers.reduce_sum(
input, dim=dim[axis:], keep_dim=False, name=name)
elif operation == 2: ## operation = ASUM
absout = fluid.layers.abs(input)
output = fluid.layers.reduce_sum(
absout, dim=dim[axis:], keep_dim=False, name=name)
elif operation == 3: ## operation = SUMSQ
powout = fluid.layers.pow(x=input, factor=2.0)
output = fluid.layers.reduce_sum(
powout, dim=dim[axis:], keep_dim=False, name=name)
else: ## operation = MEAN
output = fluid.layers.reduce_mean(
input, dim=dim[axis:], keep_dim=False, name=name)
mulout = fluid.layers.scale(x=output, scale=coeff)
return mulout
register(kind='Reduction', shape=reduction_shape, layer=reduction_layer)
""" this module provides 'register' for registering customized layers
"""
g_custom_layers = {}
def register(kind, shape, layer):
""" register a custom layer or a list of custom layers
Args:
@kind (str or list): type name of the layer
@shape (function): a function to generate the shape of layer's output
@layer (function): a function to generate the shape of layer's output
Returns:
None
"""
assert type(shape).__name__ == 'function', 'shape should be a function'
assert type(layer).__name__ == 'function', 'layer should be a function'
if type(kind) is str:
kind = [kind]
else:
assert type(
kind) is list, 'invalid param "kind" for register, not a list or str'
for k in kind:
assert type(
k) is str, 'invalid param "kind" for register, not a list of str'
assert k not in g_custom_layers, 'this type[%s] has already been registered' % (
k)
print('register layer[%s]' % (k))
g_custom_layers[k] = {'shape': shape, 'layer': layer}
def get_registered_layers():
return g_custom_layers
""" a custom layer for 'reshape', maybe we should implement this in standard way.
more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/reshape.html
"""
from .register import register
def import_fluid():
import paddle.fluid as fluid
return fluid
def reshape_shape(input_sp, shape, axis=0, num_axes=-1):
""" calculate the output shape of this layer using input shape
Args:
@input_shape (list of num): a list of number which represents the input shape
@shape (object): parameter from caffe's Reshape layer
@axis (int): parameter from caffe's Reshape layer
@num_axes(int): parameter from caffe's Reshape layer
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
def count(num_list):
return reduce(lambda a, b: a * b, num_list)
input_shape = list(input_sp)
input_count = count(input_shape)
input_num_axes = len(input_shape)
input_start_axis = axis
start_axis = input_start_axis if input_start_axis >= 0 \
else input_num_axes + input_start_axis + 1
assert start_axis >= 0, "[Reshape]axis %d out of range" % (input_start_axis)
assert start_axis <= input_num_axes, "[Reshape]axis %d out of range for %d-D input data"\
% (input_start_axis, input_num_axes)
assert num_axes >= -1, "[Reshape]num_axes must be >= 0, or -1 for all"
end_axis = input_num_axes if num_axes == -1 else start_axis + num_axes
assert end_axis <= input_num_axes, "end_axis[%d] = axis[%d] + num_axes[%d] is out of range"\
% (end_axis, start_axis, num_axes)
num_axes_replaced = end_axis - start_axis
num_axes_retained = input_num_axes - num_axes_replaced
num_new_axes = len(shape['dim'])
output_shape = []
for i in range(start_axis):
output_shape.append(input_shape[i])
for i in range(num_new_axes):
output_shape.append(shape['dim'][i])
for i in range(end_axis, input_num_axes):
output_shape.append(input_shape[i])
assert len(output_shape) == num_axes_retained + num_new_axes,\
"[Reshape]invalid dims of output shape[%s]" % (str(output_shape))
inferred_axis = -1
copy_axes = []
constant_count = 1
for i in range(num_new_axes):
top_dim = shape['dim'][i]
if top_dim == 0:
copy_axes.append(i)
copy_axis_index = start_axis + i
output_shape[copy_axis_index] = input_shape[copy_axis_index]
elif top_dim == -1:
assert inferred_axis == -1, "[Reshape]new shape contains multiple -1 dims"
inferred_axis = i
else:
constant_count *= top_dim
if inferred_axis >= 0:
explicit_count = constant_count
l = input_shape[0:start_axis]
if len(l) > 0:
explicit_count *= count(l)
l = input_shape[end_axis:]
if len(l) > 0:
explicit_count *= count(l)
for i in range(len(copy_axes)):
explicit_count *= output_shape[start_axis + copy_axes[i]]
assert input_count % explicit_count == 0, "[Reshape]botom count[%d] "\
"must be divisible by product of the specified dimensions[%d] "\
% (input_count, explicit_count)
output_shape[start_axis + inferred_axis] = input_count / explicit_count
output_count = count(output_shape)
assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
output_count, input_count)
return output_shape
def reshape_layer(input, name, shape, axis=0, num_axes=-1):
""" build a layer of type 'Flatten' using fluid
Args:
@input (variable): input fluid variable for this layer
@name (str): name for this layer
@shape (object): parameter from caffe's Reshape layer
@axis (int): parameter from caffe's Reshape layer
@num_axes(int): parameter from caffe's Reshape layer
Returns:
output (variable): output variable for this layer
"""
fluid = import_fluid()
input_shape = list(input.shape)
if input_shape[0] == -1:
input_shape[0] = 1
output_shape = reshape_shape(input_shape, shape, axis, num_axes)
output_shape[0] = -1
else:
output_shape = reshape_shape(input_shape, shape, axis, num_axes)
output = fluid.layers.reshape(input, shape=output_shape, name=name)
return output
register(kind='Reshape', shape=reshape_shape, layer=reshape_layer)
""" a custom layer for 'ROIPooling', maybe we should implement this in standard way.
more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/ROIPooling.html
"""
from .register import register
def roipooling_shape(input_shapes, pooled_h, pooled_w, spatial_scale):
""" calculate the output shape of this layer using input shape
Args:
@input_shape (list of num): a list of number which represents the input shape
@out_max_val (bool): parameter from caffe's ROIPooling layer
@top_k (int): parameter from caffe's ROIPooling layer
@axis (int): parameter from caffe's ROIPooling layer
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
assert len(input_shapes) == 2, "not valid input shape for roipooling layer"
base_fea_shape = input_shapes[0]
rois_shape = input_shapes[1]
output_shape = base_fea_shape
output_shape[0] = rois_shape[0]
output_shape[2] = pooled_h
output_shape[3] = pooled_w
return output_shape
def roipooling_layer(inputs, name, pooled_h, pooled_w, spatial_scale):
""" build a layer of type 'ROIPooling' using fluid
Args:
@input (variable): input fluid variable for this layer
@name (str): name for this layer
@out_max_val (bool): parameter from caffe's ROIPooling layer
@top_k (int): parameter from caffe's ROIPooling layer
@axis (int): parameter from caffe's ROIPooling layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
assert len(inputs) == 2, "not valid input shape for roipooling layer"
base_fea = inputs[0]
rois = inputs[1][:, 1:5]
rois_fea = fluid.layers.roi_pool(base_fea, rois, pooled_h, pooled_w,
spatial_scale)
return rois_fea
register(kind='ROIPooling', shape=roipooling_shape, layer=roipooling_layer)
""" a custom layer for 'select' which is used to replace standard 'Slice' layer
for converting layer with multiple different output tensors
"""
from .register import register
def select_shape(input_shape, slice_point, axis=1):
""" calculate the output shape of this layer using input shape
Args:
@input_shape (list of num): a list of number which represents the input shape
@slice_point (list): parameter from caffe's Slice layer
@axis (int): parameter from caffe's Slice layer
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
input_shape = list(input_shape)
start = slice_point[0]
if len(slice_point) == 2:
end = slice_point[1]
else:
end = input_shape[axis]
assert end > start, "invalid slice_point with [start:%d, end:%d]"\
% (start, end)
output_shape = input_shape
output_shape[axis] = end - start
return output_shape
def select_layer(input, name, slice_point, axis=1):
""" build a layer of type 'Slice' using fluid
Args:
@input (variable): input fluid variable for this layer
@name (str): name for this layer
@slice_point (list): parameter from caffe's Slice layer
@axis (int): parameter from caffe's Slice layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
input_shape = list(input.shape)
start = slice_point[0]
if len(slice_point) == 2:
end = slice_point[1]
else:
end = input_shape[axis]
sections = []
if start > 0:
sections.append(start)
pos = len(sections)
sections.append(end - start)
if end != input_shape[axis]:
sections.append(input_shape[axis] - end)
outputs = fluid.layers.split(input, sections, dim=axis, name=name)
return outputs[pos]
register(kind='Select', shape=select_shape, layer=select_layer)
import sys
#debug level, can be 'warn', 'verbose'
log_level = 'warn'
class KaffeError(Exception):
pass
def print_stderr(msg):
sys.stderr.write('%s\n' % msg)
def debug(msg):
if log_level == 'verbose':
print_stderr('[DEBUG]' + msg)
def notice(msg):
print_stderr('[NOTICE]' + msg)
def warn(msg):
print_stderr('[WARNING]' + msg)
def set_loglevel(level):
global log_level
if 'warn' != level and 'verbose' != level:
raise Exception('not supported log level[%s]' % (level))
log_level = level
from google.protobuf import text_format
from .caffe import get_caffe_resolver
from .errors import KaffeError, print_stderr
from .layers import LayerAdapter, LayerType, NodeKind, NodeDispatch
from .shapes import make_tensor
class Node(object):
def __init__(self, name, kind, layer=None):
self.name = name
self.kind = kind
self.layer = LayerAdapter(layer, kind) if layer else None
self.parents = []
self.children = []
self.data = None #parameters of this node
self.output_shape = None #output shape of this node
self.metadata = {}
def add_parent(self, parent_node):
assert parent_node not in self.parents
self.parents.append(parent_node)
if self not in parent_node.children:
parent_node.children.append(self)
def add_child(self, child_node):
assert child_node not in self.children
self.children.append(child_node)
if self not in child_node.parents:
child_node.parents.append(self)
def get_only_parent(self):
if len(self.parents) != 1:
raise KaffeError('Node (%s) expected to have 1 parent. Found %s.' %
(self, len(self.parents)))
return self.parents[0]
@property
def parameters(self):
""" get parameters stored in a protobuf object
"""
if self.layer is not None:
return self.layer.parameters
return None
@property
def params(self):
""" get parameters stored in a dict
"""
from .protobuf_to_dict import protobuf_to_dict
p = self.parameters
if p is not None:
return protobuf_to_dict(p)
else:
return None
def __str__(self):
return '[%s] %s' % (self.kind, self.name)
def __repr__(self):
return '%s (0x%x)' % (self.name, id(self))
class Graph(object):
def __init__(self, nodes=None, name=None, trace={}):
self.nodes = nodes or []
self.node_lut = {node.name: node for node in self.nodes}
self.output_trace = trace
if name is None or name == '':
self.name = 'MyNet'
else:
self.name = name
def add_node(self, node):
self.nodes.append(node)
self.node_lut[node.name] = node
def get_node(self, name):
try:
return self.node_lut[name]
except KeyError:
raise KaffeError('Layer not found: %s' % name)
def add_name_trace(self, trace, which='caffe'):
self.output_trace[which] = trace
def get_name_trace(self, which=None):
if which is not None:
return self.output_trace[which]
else:
return self.output_trace
def get_input_nodes(self):
return [node for node in self.nodes if len(node.parents) == 0]
def get_output_nodes(self):
return [node for node in self.nodes if len(node.children) == 0]
def topologically_sorted(self):
sorted_nodes = []
unsorted_nodes = list(self.nodes)
temp_marked = set()
perm_marked = set()
def visit(node):
if node in temp_marked:
raise KaffeError('Graph is not a DAG.')
if node in perm_marked:
return
temp_marked.add(node)
for child in node.children:
visit(child)
perm_marked.add(node)
temp_marked.remove(node)
sorted_nodes.insert(0, node)
while len(unsorted_nodes):
visit(unsorted_nodes.pop())
return sorted_nodes
def compute_output_shapes(self):
sorted_nodes = self.topologically_sorted()
for node in sorted_nodes:
node.output_shape = make_tensor(
*NodeKind.compute_output_shape(node))
def replaced(self, new_nodes):
return Graph(nodes=new_nodes, name=self.name, trace=self.output_trace)
def transformed(self, transformers):
graph = self
for transformer in transformers:
graph = transformer(graph)
if graph is None:
raise KaffeError('Transformer failed: {}'.format(transformer))
assert isinstance(graph, Graph)
return graph
def __contains__(self, key):
return key in self.node_lut
def __str__(self):
hdr = '{:<20} {:<30} {:>20} {:>20}'.format('Type', 'Name', 'Param',
'Output')
s = [hdr, '-' * 94]
for node in self.topologically_sorted():
# If the node has learned parameters, display the first one's shape.
# In case of convolutions, this corresponds to the weights.
if node.data is None:
data_shape = '--'
out_shape = node.output_shape or '--'
s.append('{:<20} {:<30} {:>20} {:>20}'.format(
node.kind, node.name, data_shape, tuple(out_shape)))
else:
for d in node.data:
#data_shape = node.data[0].shape if node.data else '--'
data_shape = d.shape
out_shape = node.output_shape or '--'
s.append('{:<20} {:<30} {:>20} {:>20}'.format(
node.kind, node.name, data_shape, tuple(out_shape)))
return '\n'.join(s)
class GraphBuilder(object):
'''Constructs a model graph from a Caffe protocol buffer definition.'''
def __init__(self, def_path, phase='test'):
'''
def_path: Path to the model definition (.prototxt)
data_path: Path to the model data (.caffemodel)
phase: Either 'test' or 'train'. Used for filtering phase-specific nodes.
'''
self.def_path = def_path
self.phase = phase
self.load()
def load(self):
'''Load the layer definitions from the prototxt.'''
self.params = get_caffe_resolver().NetParameter()
with open(self.def_path, 'rb') as def_file:
text_format.Merge(def_file.read(), self.params)
def filter_layers(self, layers):
'''Filter out layers based on the current phase.'''
phase_map = {0: 'train', 1: 'test'}
filtered_layer_names = set()
filtered_layers = []
for layer in layers:
phase = self.phase
if len(layer.include):
phase = phase_map[layer.include[0].phase]
if len(layer.exclude):
phase = phase_map[1 - layer.include[0].phase]
exclude = (phase != self.phase)
# Dropout layers appear in a fair number of Caffe
# test-time networks. These are just ignored. We'll
# filter them out here.
if (not exclude) and (phase == 'test'):
exclude = (layer.type == LayerType.Dropout)
if not exclude:
filtered_layers.append(layer)
# Guard against dupes.
assert layer.name not in filtered_layer_names
filtered_layer_names.add(layer.name)
return filtered_layers
def make_node(self, layer):
'''Create a graph node for the given layer.'''
kind = NodeKind.map_raw_kind(layer.type)
if kind is None:
raise KaffeError('Unknown layer type encountered: %s' % layer.type)
# We want to use the layer's top names (the "output" names), rather than the
# name attribute, which is more of readability thing than a functional one.
# Other layers will refer to a node by its "top name".
return Node(layer.name, kind, layer=layer)
def make_input_nodes(self):
'''
Create data input nodes.
This method is for old-style inputs, where the input specification
was not treated as a first-class layer in the prototext.
Newer models use the "Input layer" type.
'''
nodes = [Node(name, NodeKind.Data) for name in self.params.input]
inputs_num = len(nodes)
if inputs_num > 0:
input_dims_num = len(self.params.input_dim)
if input_dims_num > 0 and input_dims_num != inputs_num * 4:
raise KaffeError('invalid input_dim[%d] param in prototxt' %
(input_dims_num))
input_dims = [[]] * inputs_num
for i in range(input_dims_num):
dim = self.params.input_dim[i]
which = int(i / 4)
input_dims[which].append(int(dim))
for i in range(inputs_num):
if len(self.params.input_shape) == inputs_num:
input_dim = map(int, self.params.input_shape[i].dim)
input_dims[i] = input_dim
nodes[i].output_shape = tuple(input_dims[i])
return nodes
def build(self):
'''
Builds the graph from the Caffe layer definitions.
'''
# Get the layers
layers = self.params.layers or self.params.layer
# Filter out phase-excluded layers
layers = self.filter_layers(layers)
# Get any separately-specified input layers
nodes = self.make_input_nodes()
nodes += [self.make_node(layer) for layer in layers]
# Initialize the graph
graph = Graph(nodes=nodes, name=self.params.name)
# Connect the nodes
#
# A note on layers and outputs:
# In Caffe, each layer can produce multiple outputs ("tops") from a set of inputs
# ("bottoms"). The bottoms refer to other layers' tops. The top can rewrite a bottom
# (in case of in-place operations). Note that the layer's name is not used for establishing
# any connectivity. It's only used for data association. By convention, a layer with a
# single top will often use the same name (although this is not required).
#
# The current implementation only supports single-output nodes (note that a node can still
# have multiple children, since multiple child nodes can refer to the single top's name).
node_outputs = {}
output_trace = {}
for layer in layers:
node = graph.get_node(layer.name)
for input_name in layer.bottom:
assert input_name != layer.name
parent_node = node_outputs.get(input_name)
if (parent_node is None) or (parent_node == node):
parent_node = graph.get_node(input_name)
node.add_parent(parent_node)
if len(layer.top) > 1:
raise KaffeError('Multiple top nodes are not supported.')
for output_name in layer.top:
if output_name == layer.name:
# Output is named the same as the node. No further action required.
continue
# There are two possibilities here:
#
# Case 1: output_name refers to another node in the graph.
# This is an "in-place operation" that overwrites an existing node.
# This would create a cycle in the graph. We'll undo the in-placing
# by substituting this node wherever the overwritten node is referenced.
#
# Case 2: output_name violates the convention layer.name == output_name.
# Since we are working in the single-output regime, we will can rename it to
# match the layer name.
#
# For both cases, future references to this top re-routes to this node.
node_outputs[output_name] = node
if output_name in output_trace:
output_trace[output_name].append(node.name)
else:
output_trace[output_name] = [output_name, node.name]
#build a mapping from real-name to changed-name(for caffe's INPLACE inference)
real2chg = {}
deleted = {}
for k, v in output_trace.items():
real2chg[v[-1]] = k
for n in v:
if n in real2chg:
continue
if n not in deleted:
deleted[n] = '%s.%s' % (k, v[-1])
graph.add_name_trace({
'real2chg': real2chg,
'deleted': deleted
}, 'caffe')
graph.compute_output_shapes()
return graph
class NodeMapper(NodeDispatch):
def __init__(self, graph):
self.graph = graph
def map(self):
nodes = self.graph.topologically_sorted()
# Remove input nodes - we'll handle them separately.
input_nodes = self.graph.get_input_nodes()
nodes = [t for t in nodes if t not in input_nodes]
# Decompose DAG into chains.
chains = []
for node in nodes:
attach_to_chain = None
if len(node.parents) == 1:
parent = node.get_only_parent()
for chain in chains:
if chain[-1] == parent:
# Node is part of an existing chain.
attach_to_chain = chain
break
if attach_to_chain is None:
# Start a new chain for this node.
attach_to_chain = []
chains.append(attach_to_chain)
attach_to_chain.append(node)
# Map each chain.
mapped_chains = []
for chain in chains:
mapped_chains.append(self.map_chain(chain))
return self.commit(mapped_chains)
def map_chain(self, chain):
return [self.map_node(node) for node in chain]
def map_node(self, node):
map_func = self.get_handler(node.kind, 'map')
mapped_node = map_func(node)
assert mapped_node is not None
mapped_node.node = node
return mapped_node
def commit(self, mapped_chains):
raise NotImplementedError('Must be implemented by subclass.')
import re
import numbers
from collections import namedtuple
import custom_layers
from .shapes import *
LAYER_DESCRIPTORS = {
# Caffe Types
'AbsVal': shape_identity,
'Accuracy': shape_scalar,
'ArgMax': shape_not_implemented,
'BatchNorm': shape_identity,
'BNLL': shape_not_implemented,
'Concat': shape_concat,
'ContrastiveLoss': shape_scalar,
'Convolution': shape_convolution,
'Deconvolution': shape_deconvolution,
'Data': shape_data,
'Dropout': shape_identity,
'DummyData': shape_data,
'Crop': shape_crop,
'EuclideanLoss': shape_scalar,
'Eltwise': shape_identity,
'Exp': shape_identity,
'Flatten': shape_not_implemented,
'HDF5Data': shape_data,
'HDF5Output': shape_identity,
'HingeLoss': shape_scalar,
'Im2col': shape_not_implemented,
'ImageData': shape_data,
'InfogainLoss': shape_scalar,
'InnerProduct': shape_inner_product,
'Input': shape_data,
'LRN': shape_identity,
'MemoryData': shape_mem_data,
'MultinomialLogisticLoss': shape_scalar,
'MVN': shape_not_implemented,
'Pooling': shape_pool,
'Power': shape_power,
'ReLU': shape_identity,
'PReLU': shape_identity,
'Scale': shape_identity,
'Sigmoid': shape_identity,
'SigmoidCrossEntropyLoss': shape_scalar,
'Silence': shape_not_implemented,
'Softmax': shape_identity,
'SoftmaxWithLoss': shape_scalar,
'Split': shape_not_implemented,
'Slice': shape_not_implemented,
'TanH': shape_identity,
'WindowData': shape_not_implemented,
'Threshold': shape_identity,
}
# layer types in 'V1LayerParameter'
# (v1layertype name, enum value, mapped to layer type)
v1_layertypes = [
('ABSVAL', 35),
('ACCURACY', 1),
('ARGMAX', 30),
('BNLL', 2),
('CONCAT', 3),
('CONVOLUTION', 4),
('DATA', 5),
('DECONVOLUTION', 39),
('DROPOUT', 6),
('ELTWISE', 25),
('EXP', 38),
('FLATTEN', 8),
('IM2COL', 11),
('INNERPRODUCT', 14),
('LRN', 15),
('MEMORYDATA', 29),
('MULTINOMIALLOGISTICLOSS', 16),
('MVN', 34),
('POOLING', 17),
('POWER', 26),
('RELU', 18),
('SIGMOID', 19),
('SIGMOIDCROSSENTROPYLOSS', 27),
('SILENCE', 36),
('SOFTMAX', 20),
('SPLIT', 22),
('SLICE', 33),
('TANH', 23),
('WINDOWDATA', 24),
('THRESHOLD', 31),
]
LAYER_TYPES = LAYER_DESCRIPTORS.keys()
LayerType = type('LayerType', (), {t: t for t in LAYER_TYPES})
#map the layer name in V1 to standard name
V1_LAYER_MAP = {'_not_init_': True}
def get_v1_layer_map():
global V1_LAYER_MAP
if '_not_init_' not in V1_LAYER_MAP:
return V1_LAYER_MAP
else:
del V1_LAYER_MAP['_not_init_']
name2layer = {}
for n in LAYER_TYPES:
name2layer[n.upper()] = n
for l in v1_layertypes:
n, v = l
if n in name2layer and v not in V1_LAYER_MAP:
V1_LAYER_MAP[v] = name2layer[n]
else:
raise KaffeError('not found v1 layer type %s' % n)
return V1_LAYER_MAP
class NodeKind(LayerType):
@staticmethod
def map_raw_kind(kind):
if custom_layers.has_layer(kind):
return kind
if kind in LAYER_TYPES:
return kind
v1_layers = get_v1_layer_map()
if kind in v1_layers:
return v1_layers[kind]
else:
return None
@staticmethod
def compute_output_shape(node):
if custom_layers.has_layer(node.kind):
return custom_layers.compute_output_shape(node.kind, node)
try:
val = LAYER_DESCRIPTORS[node.kind](node)
return val
except NotImplementedError:
raise KaffeError(
'Output shape computation not implemented for type: %s' %
node.kind)
class NodeDispatchError(KaffeError):
pass
class NodeDispatch(object):
@staticmethod
def get_handler_name(node_kind):
if len(node_kind) <= 6:
# A catch-all for things like ReLU and tanh
return node_kind.lower()
# Convert from CamelCase to under_scored
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', node_kind)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
def get_handler(self, node_kind, prefix):
if custom_layers.has_layer(node_kind):
return getattr(self, 'map_custom')
name = self.get_handler_name(node_kind)
name = '_'.join((prefix, name))
try:
return getattr(self, name)
except AttributeError:
raise NodeDispatchError(
'No handler found for node kind: %s (expected: %s)' %
(node_kind, name))
class LayerAdapter(object):
def __init__(self, layer, kind):
self.layer = layer
self.kind = kind
@property
def parameters(self):
name = NodeDispatch.get_handler_name(self.kind)
if self.kind.lower() == "normalize":
name = "norm"
elif self.kind.lower() == "deconvolution":
name = "convolution"
name = '_'.join((name, 'param'))
try:
return getattr(self.layer, name)
except AttributeError:
print(dir(self.layer))
raise NodeDispatchError(
'Caffe parameters not found attr[%s] for layer kind[%s]' %
(name, self.kind))
@staticmethod
def get_kernel_value(scalar, repeated, idx, default=None):
if scalar:
return scalar
if repeated:
if isinstance(repeated, numbers.Number):
return repeated
if len(repeated) == 1:
# Same value applies to all spatial dimensions
return int(repeated[0])
assert idx < len(repeated)
# Extract the value for the given spatial dimension
return repeated[idx]
if default is None:
raise ValueError('Unable to determine kernel parameter!')
return default
@property
def kernel_parameters(self):
assert self.kind in (NodeKind.Convolution, NodeKind.Pooling,\
NodeKind.Deconvolution)
params = self.parameters
k_h = self.get_kernel_value(params.kernel_h, params.kernel_size, 0)
k_w = self.get_kernel_value(params.kernel_w, params.kernel_size, 1)
s_h = self.get_kernel_value(
params.stride_h, params.stride, 0, default=1)
s_w = self.get_kernel_value(
params.stride_w, params.stride, 1, default=1)
p_h = self.get_kernel_value(params.pad_h, params.pad, 0, default=0)
p_w = self.get_kernel_value(params.pad_w, params.pad, 1, default=0)
dila_h = dila_w = 1
if self.kind in (NodeKind.Convolution, NodeKind.Deconvolution):
dila_len = len(params.dilation)
if dila_len == 2:
dila_h = params.dilation[0]
dila_w = params.dilation[1]
elif dila_len == 1:
dila_h = dila_w = params.dilation[0]
else:
assert dila_len == 0, "invalid length[%s] of dilation in convolution" % (
dila_len)
return KernelParameters(k_h, k_w, s_h, s_w, p_h, p_w, dila_h, dila_w)
KernelParameters = namedtuple(
'KernelParameters',
[
'kernel_h', 'kernel_w', 'stride_h', 'stride_w', 'pad_h', 'pad_w',
'dila_h', 'dila_w'
], )
""" this module is used as a template for generating sub class of Network
"""
class MyNet(object):
### automatically generated by caffe2fluid ###
inputs_info = "INPUTS_INFO"
custom_layers_path = "_CAFFE2FLUID_CUSTOM_LAYERS_"
def custom_layer_factory(self):
import os
pk_paths = []
default = os.path.dirname(os.path.abspath(__file__))
location = os.environ.get('CAFFE2FLUID_CUSTOM_LAYERS', default)
pk_name = 'custom_layers'
pk_dir = os.path.join(location, pk_name)
pk_paths.append((location, pk_dir))
location = MyNet.custom_layers_path
pk_dir = os.path.join(MyNet.custom_layers_path, pk_name)
pk_paths.append((location, pk_dir))
for loc, pk_dir in pk_paths:
if os.path.exists(pk_dir):
if loc not in sys.path:
sys.path.insert(0, loc)
break
try:
from custom_layers import make_custom_layer
return make_custom_layer
except Exception as e:
print('maybe you should set $CAFFE2FLUID_CUSTOM_LAYERS first')
raise e
@classmethod
def input_shapes(cls):
return cls.inputs_info
@classmethod
def convert(cls, npy_model, fluid_path, outputs=None):
fluid = import_fluid()
shapes = cls.input_shapes()
input_name = shapes.keys()[0]
feed_data = {}
for name, shape in shapes.items():
data_layer = fluid.layers.data(
name=name, shape=shape, dtype="float32")
feed_data[name] = data_layer
net = cls(feed_data)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
net.load(data_path=npy_model, exe=exe, place=place)
output_vars = []
model_filename = 'model'
params_filename = 'params'
if outputs is None:
output_vars.append(net.get_output())
else:
if outputs[0] == 'dump_all':
model_filename = None
params_filename = None
output_vars.append(net.get_output())
else:
if type(outputs) is list:
for n in outputs:
assert n in net.layers, 'not found layer with this name[%s]' % (
n)
output_vars.append(net.layers[n])
fluid.io.save_inference_model(
fluid_path, [input_name],
output_vars,
exe,
main_program=None,
model_filename=model_filename,
params_filename=params_filename)
return 0
def main():
""" a tool used to convert caffe model to fluid
"""
import sys
import os
filename = os.path.splitext(os.path.basename(sys.argv[0]))[0]
if len(sys.argv) < 3:
print('usage:')
print(' python %s %s.npy [save_dir] [layer names seperated by comma]' \
% (sys.argv[0], filename))
print(' eg: python %s %s.npy ./fluid' % (sys.argv[0], filename))
print(' eg: python %s %s.npy ./fluid layer_name1,layer_name2' \
% (sys.argv[0], filename))
return 1
npy_weight = sys.argv[1]
fluid_model = sys.argv[2]
outputs = None
if len(sys.argv) >= 4:
outputs = sys.argv[3].split(',')
ret = MyNet.convert(npy_weight, fluid_model, outputs)
if ret == 0:
outputs = 'last output layer' if outputs is None else outputs
print('succeed to convert to fluid format with output layers[%s]'
' in directory[%s]' % (outputs, fluid_model))
else:
print('failed to convert model to fluid format')
return ret
def generate_net_code(net_name, inputs_info):
""" generate framework of a custom net code which represent a subclass of Network
Args:
@net_name (str): class name for this net
@inputs_info (str): a str which represents a dict, eg: '{"data": [3, 32, 32]}'
Returns:
net_codes (str): codes for this subclass
"""
import os
import inspect
net_codes = str(inspect.getsource(MyNet))
net_codes = net_codes.replace('MyNet(object)', '%s(Network)' % net_name)
net_codes = net_codes.replace('MyNet', net_name)
net_codes = net_codes.replace('"INPUTS_INFO"', inputs_info)
custom_layer_dir = os.path.dirname(os.path.abspath(__file__))
net_codes = net_codes.replace('_CAFFE2FLUID_CUSTOM_LAYERS_',
custom_layer_dir)
return net_codes
def generate_main_code(net_name):
""" generate a piece of code for 'main' function
Args:
@net_name (str): class name for this net
Returns:
main_codes (str): codes for this main function
"""
import inspect
main_codes = str(inspect.getsource(main))
main_codes = main_codes.replace('MyNet', net_name)
return main_codes
if __name__ == "__main__":
""" just for testing
"""
print generate_net_code('Attribute', "{'data': [3, 277, 277]}")
print generate_main_code('Attribute')
from .transformer import Transformer
from .network import Network
import sys
import os
import math
import numpy as np
def import_fluid():
import paddle.fluid as fluid
return fluid
def layer(op):
'''Decorator for composable network layers.'''
def layer_decorated(self, *args, **kwargs):
# Automatically set a name if not provided.
name = kwargs.setdefault('name', self.get_unique_name(op.__name__))
# Figure out the layer inputs.
if len(self.terminals) == 0:
raise RuntimeError('No input variables found for layer %s.' % name)
elif len(self.terminals) == 1:
layer_input = self.terminals[0]
else:
layer_input = list(self.terminals)
self.layer_reverse_trace[name] = layer_input
# Perform the operation and get the output.
layer_output = op(self, layer_input, *args, **kwargs)
# Add to layer LUT.
self.layers[name] = layer_output
self.var2name[layer_output.name] = (name, layer_output)
# This output is now the input for the next layer.
self.feed(layer_output)
# Return self for chained calls.
return self
return layer_decorated
class Network(object):
def __init__(self, inputs, trainable=True):
# The input nodes for this network
self.inputs = inputs
# The current list of terminal nodes
self.terminals = []
# Mapping from layer names to layers
self.layers = dict(inputs)
# If true, the resulting variables are set as trainable
self.trainable = trainable
# Switch variable for dropout
self.paddle_env = None
self.output_names = []
self.name_trace = None
self.layer_reverse_trace = {}
self.var2name = {}
self.setup()
def setup(self):
'''Construct the network. '''
raise NotImplementedError('Must be implemented by the subclass.')
def locate_ancestor(self, v, which=[0], ancestor_level=1):
""" find a ancestor for a node 'v' which is a fluid variable
"""
ancestor = None
which = which * ancestor_level
name = self.var2name[v.name][0]
for i in range(ancestor_level):
v = self.layer_reverse_trace[name]
if type(v) is list:
ancestor = self.var2name[v[which[i]].name]
else:
ancestor = self.var2name[v.name]
name = ancestor[0]
return ancestor
def load(self, data_path, exe=None, place=None, ignore_missing=False):
'''Load network weights.
data_path: The path to the numpy-serialized network weights
ignore_missing: If true, serialized weights for missing layers are ignored.
'''
fluid = import_fluid()
#load fluid mode directly
if os.path.isdir(data_path):
assert (exe is not None), \
'must provide a executor to load fluid model'
fluid.io.load_persistables(executor=exe, dirname=data_path)
return True
#load model from a npy file
if exe is None or place is None:
if self.paddle_env is None:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
self.paddle_env = {'place': place, 'exe': exe}
exe = exe.run(fluid.default_startup_program())
else:
place = self.paddle_env['place']
exe = self.paddle_env['exe']
data_dict = np.load(data_path).item()
for op_name in data_dict:
if op_name == 'caffe2fluid_name_trace':
self.name_trace = data_dict[op_name]
continue
layer = self.layers[op_name]
for param_name, data in data_dict[op_name].iteritems():
try:
name = '%s_%s' % (op_name, param_name)
v = fluid.global_scope().find_var(name)
w = v.get_tensor()
w.set(data.reshape(w.shape()), place)
except ValueError:
if not ignore_missing:
raise
return True
def feed(self, *args):
'''Set the input(s) for the next operation by replacing the terminal nodes.
The arguments can be either layer names or the actual layers.
'''
assert len(args) != 0
self.terminals = []
for fed_layer in args:
if isinstance(fed_layer, basestring):
try:
fed_layer = self.layers[fed_layer]
except KeyError:
raise KeyError('Unknown layer name fed: %s' % fed_layer)
self.terminals.append(fed_layer)
return self
def get_output(self):
'''Returns the current network output.'''
return self.terminals[-1]
def get_unique_name(self, prefix):
'''Returns an index-suffixed unique name for the given prefix.
This is used for auto-generating layer names based on the type-prefix.
'''
ident = sum(t.startswith(prefix) for t, _ in self.layers.items()) + 1
return '%s_%d' % (prefix, ident)
def get_unique_output_name(self, prefix, layertype):
'''Returns an index-suffixed unique name for the given prefix.
This is used for auto-generating layer names based on the type-prefix.
'''
ident = sum(t.startswith(prefix) for t in self.output_names) + 1
unique_name = '%s.%s.output.%d' % (prefix, layertype, ident)
self.output_names.append(unique_name)
return unique_name
@layer
def conv(self,
input,
k_h,
k_w,
c_o,
s_h,
s_w,
name,
relu=True,
relu_negative_slope=0.0,
padding=None,
dilation=1,
group=1,
biased=True):
if padding is None:
padding = [0, 0]
# Get the number of channels in the input
c_i, h_i, w_i = input.shape[1:]
# Verify that the grouping parameter is valid
assert c_i % group == 0
assert c_o % group == 0
fluid = import_fluid()
prefix = name + '_'
leaky_relu = False
act = 'relu'
if relu is False:
act = None
elif relu_negative_slope != 0.0:
leaky_relu = True
act = None
output = fluid.layers.conv2d(
name=self.get_unique_output_name(name, 'conv2d'),
input=input,
filter_size=[k_h, k_w],
num_filters=c_o,
stride=[s_h, s_w],
padding=padding,
dilation=dilation,
groups=group,
param_attr=fluid.ParamAttr(name=prefix + "weights"),
bias_attr=fluid.ParamAttr(name=prefix + "biases"),
act=act)
if leaky_relu:
output = fluid.layers.leaky_relu(output, alpha=relu_negative_slope)
return output
@layer
def deconv(self,
input,
k_h,
k_w,
c_o,
s_h,
s_w,
name,
relu=True,
relu_negative_slope=0.0,
padding=None,
dilation=1,
biased=True):
if padding is None:
padding = [0, 0]
# Get the number of channels in the input
c_i, h_i, w_i = input.shape[1:]
fluid = import_fluid()
prefix = name + '_'
leaky_relu = False
act = 'relu'
if relu is False:
act = None
elif relu_negative_slope != 0.0:
leaky_relu = True
act = None
p_h = padding[0]
p_w = padding[1]
h_o = (h_i - 1) * s_h - 2 * p_h + dilation * (k_h - 1) + 1
w_o = (w_i - 1) * s_w - 2 * p_w + dilation * (k_w - 1) + 1
output = fluid.layers.conv2d_transpose(
name=self.get_unique_output_name(name, 'conv2d_transpose'),
input=input,
num_filters=c_o,
output_size=[h_o, w_o],
filter_size=[k_h, k_w],
padding=padding,
stride=[s_h, s_w],
dilation=dilation,
param_attr=fluid.ParamAttr(name=prefix + "weights"),
bias_attr=fluid.ParamAttr(name=prefix + "biases"),
act=act)
if leaky_relu:
output = fluid.layers.leaky_relu(output, alpha=relu_negative_slope)
return output
@layer
def relu(self, input, name):
fluid = import_fluid()
output = fluid.layers.relu(input)
return output
@layer
def prelu(self, input, channel_shared, name):
fluid = import_fluid()
if channel_shared:
mode = 'all'
else:
mode = 'channel'
prefix = name + '_'
output = fluid.layers.prelu(
input,
mode=mode,
param_attr=fluid.ParamAttr(name=prefix + 'negslope'))
return output
def pool(self,
pool_type,
input,
k_h,
k_w,
s_h,
s_w,
ceil_mode,
padding,
name,
exclusive=True):
# Get the number of channels in the input
in_hw = input.shape[2:]
k_hw = [k_h, k_w]
s_hw = [s_h, s_w]
fluid = import_fluid()
output = fluid.layers.pool2d(
name=name,
input=input,
pool_size=k_hw,
pool_stride=s_hw,
pool_padding=padding,
ceil_mode=ceil_mode,
pool_type=pool_type,
exclusive=exclusive)
return output
@layer
def max_pool(self,
input,
k_h,
k_w,
s_h,
s_w,
ceil_mode,
padding=[0, 0],
name=None):
return self.pool(
'max',
input,
k_h,
k_w,
s_h,
s_w,
ceil_mode,
padding,
name=self.get_unique_output_name(name, 'max_pool'))
@layer
def avg_pool(self,
input,
k_h,
k_w,
s_h,
s_w,
ceil_mode,
padding=[0, 0],
name=None):
return self.pool(
'avg',
input,
k_h,
k_w,
s_h,
s_w,
ceil_mode,
padding,
name=self.get_unique_output_name(name, 'avg_pool'),
exclusive=False)
@layer
def sigmoid(self, input, name):
fluid = import_fluid()
return fluid.layers.sigmoid(
input, name=self.get_unique_output_name(name, 'sigmoid'))
@layer
def tanh(self, input, name):
fluid = import_fluid()
return fluid.layers.tanh(
input, name=self.get_unique_output_name(name, 'tanh'))
@layer
def lrn(self, input, radius, alpha, beta, name, bias=1.0):
fluid = import_fluid()
output = fluid.layers.lrn(input=input,
n=radius,
k=bias,
alpha=alpha,
beta=beta,
name=self.get_unique_output_name(name, 'lrn'))
return output
@layer
def concat(self, inputs, axis, name):
fluid = import_fluid()
output = fluid.layers.concat(
input=inputs,
axis=axis,
name=self.get_unique_output_name(name, 'concat'))
return output
@layer
def add(self, inputs, name):
fluid = import_fluid()
output = inputs[0]
for i in inputs[1:]:
output = fluid.layers.elementwise_add(
x=output, y=i, name=self.get_unique_output_name(name, 'add'))
return output
@layer
def max(self, inputs, name):
fluid = import_fluid()
output = inputs[0]
for i in inputs[1:]:
output = fluid.layers.elementwise_max(
x=output, y=i, name=self.get_unique_output_name(name, 'max'))
return output
@layer
def multiply(self, inputs, name):
fluid = import_fluid()
output = inputs[0]
for i in inputs[1:]:
output = fluid.layers.elementwise_mul(
x=output, y=i, name=self.get_unique_output_name(name, 'mul'))
return output
@layer
def fc(self, input, num_out, name, relu=True, act=None):
fluid = import_fluid()
if act is None:
act = 'relu' if relu is True else None
prefix = name + '_'
output = fluid.layers.fc(
name=self.get_unique_output_name(name, 'fc'),
input=input,
size=num_out,
act=act,
param_attr=fluid.ParamAttr(name=prefix + 'weights'),
bias_attr=fluid.ParamAttr(name=prefix + 'biases'))
return output
@layer
def softmax(self, input, axis=2, name=None):
fluid = import_fluid()
shape = input.shape
dims = len(shape)
axis = axis + dims if axis < 0 else axis
need_transpose = False
if axis + 1 != dims:
need_transpose = True
if need_transpose:
order = range(dims)
order.remove(axis)
order.append(axis)
input = fluid.layers.transpose(
input,
perm=order,
name=self.get_unique_output_name(name, 'transpose'))
output = fluid.layers.softmax(
input, name=self.get_unique_output_name(name, 'softmax'))
if need_transpose:
order = range(len(shape))
order[axis] = dims - 1
order[-1] = axis
output = fluid.layers.transpose(
output,
perm=order,
name=self.get_unique_output_name(name, 'transpose'))
return output
@layer
def batch_normalization(self,
input,
name,
scale_offset=True,
eps=1e-5,
relu=False,
relu_negative_slope=0.0):
# NOTE: Currently, only inference is supported
fluid = import_fluid()
prefix = name + '_'
param_attr = None if scale_offset is False else fluid.ParamAttr(
name=prefix + 'scale')
bias_attr = None if scale_offset is False else fluid.ParamAttr(
name=prefix + 'offset')
mean_name = prefix + 'mean'
variance_name = prefix + 'variance'
leaky_relu = False
act = 'relu'
if relu is False:
act = None
elif relu_negative_slope != 0.0:
leaky_relu = True
act = None
output = fluid.layers.batch_norm(
name=self.get_unique_output_name(name, 'batch_norm'),
input=input,
is_test=True,
param_attr=param_attr,
bias_attr=bias_attr,
moving_mean_name=mean_name,
moving_variance_name=variance_name,
epsilon=eps,
act=act)
if leaky_relu:
output = fluid.layers.leaky_relu(output, alpha=relu_negative_slope)
return output
@layer
def dropout(self, input, drop_prob, name, is_test=True):
fluid = import_fluid()
if is_test:
output = input
else:
output = fluid.layers.dropout(
input,
dropout_prob=drop_prob,
is_test=is_test,
name=self.get_unique_output_name(name, 'dropout'))
return output
@layer
def scale(self, input, axis=1, num_axes=1, name=None):
fluid = import_fluid()
assert num_axes == 1, "layer scale not support this num_axes[%d] now" % (
num_axes)
prefix = name + '_'
scale_shape = input.shape[axis:axis + num_axes]
param_attr = fluid.ParamAttr(name=prefix + 'scale')
scale_param = fluid.layers.create_parameter(
shape=scale_shape,
dtype=input.dtype,
name=name,
attr=param_attr,
is_bias=True,
default_initializer=fluid.initializer.Constant(value=1.0))
offset_attr = fluid.ParamAttr(name=prefix + 'offset')
offset_param = fluid.layers.create_parameter(
shape=scale_shape,
dtype=input.dtype,
name=name,
attr=offset_attr,
is_bias=True,
default_initializer=fluid.initializer.Constant(value=0.0))
output = fluid.layers.elementwise_mul(
input,
scale_param,
axis=axis,
name=self.get_unique_output_name(name, 'scale_mul'))
output = fluid.layers.elementwise_add(
output,
offset_param,
axis=axis,
name=self.get_unique_output_name(name, 'scale_add'))
return output
def custom_layer_factory(self):
""" get a custom layer maker provided by subclass
"""
raise NotImplementedError(
'[custom_layer_factory] must be implemented by the subclass.')
@layer
def custom_layer(self, inputs, kind, name, *args, **kwargs):
""" make custom layer
"""
#FIX ME:
# there is a trick for different API between caffe and paddle
if kind == "DetectionOutput":
conf_var = inputs[1]
real_conf_var = self.locate_ancestor(conf_var, ancestor_level=2)
inputs[1] = real_conf_var[1]
name = self.get_unique_output_name(name, kind)
layer_factory = self.custom_layer_factory()
return layer_factory(kind, inputs, name, *args, **kwargs)
import numpy as np
from ..errors import KaffeError, print_stderr
from ..graph import GraphBuilder, NodeMapper
from ..layers import NodeKind
from ..transformers import (DataInjector, DataReshaper, NodeRenamer,
SubNodeFuser, ReLUFuser, BatchNormScaleBiasFuser,
BatchNormPreprocessor, ParameterNamer, CropFuser)
from . import network
class PaddleNode(object):
'''An intermediate representation for Paddle operations.'''
def __init__(self, op, *args, **kwargs):
# A string corresponding to the Paddle operation
self.op = op
# Positional arguments for the operation
self.args = args
# Keyword arguments for the operation
self.kwargs = list(kwargs.items())
# The source Caffe node
self.node = None
def format(self, arg):
'''Returns a string representation for the given value.'''
return "'%s'" % arg if isinstance(arg, basestring) else str(arg)
def pair(self, key, value):
'''Returns key=formatted(value).'''
return '%s=%s' % (key, self.format(value))
def emit(self):
'''Emits the Python source for this node.'''
# Format positional arguments
args = map(self.format, self.args)
# Format any keyword arguments
if self.kwargs:
args += [self.pair(k, v) for k, v in self.kwargs]
# Set the node name
args.append(self.pair('name', self.node.name))
args = ', '.join(args)
return '%s(%s)' % (self.op, args)
class MaybeActivated(object):
def __init__(self, node, default=True):
self.inject_kwargs = {}
if node.metadata.get('relu', False) != default:
self.inject_kwargs['relu'] = not default
default_slope = 0.0
slope = node.metadata.get('relu_negative_slope', default_slope)
if slope != default_slope:
self.inject_kwargs['relu_negative_slope'] = slope
def __call__(self, *args, **kwargs):
kwargs.update(self.inject_kwargs)
return PaddleNode(*args, **kwargs)
class PaddleMapper(NodeMapper):
def get_kernel_params(self, node):
kernel_params = node.layer.kernel_parameters
input_shape = node.get_only_parent().output_shape
padding = [kernel_params.pad_h, kernel_params.pad_w]
if padding[0] == 0 and padding[1] == 0:
padding = {}
else:
padding = {'padding': padding}
return (kernel_params, padding)
def map_convolution(self, node):
(kernel_params, kwargs) = self.get_kernel_params(node)
h = kernel_params.kernel_h
w = kernel_params.kernel_w
c_o = node.output_shape[1]
c_i = node.parents[0].output_shape[1]
group = node.parameters.group
if group != 1:
kwargs['group'] = group
if not node.parameters.bias_term:
kwargs['biased'] = False
if kernel_params.dila_h != 1 or kernel_params.dila_w != 1:
kwargs['dilation'] = (kernel_params.dila_h, kernel_params.dila_w)
assert kernel_params.kernel_h == h
assert kernel_params.kernel_w == w
return MaybeActivated(node)(
'conv', kernel_params.kernel_h, kernel_params.kernel_w, c_o,
kernel_params.stride_h, kernel_params.stride_w, **kwargs)
def map_deconvolution(self, node):
(kernel_params, kwargs) = self.get_kernel_params(node)
h = kernel_params.kernel_h
w = kernel_params.kernel_w
c_o = node.output_shape[1]
c_i = node.parents[0].output_shape[1]
if not node.parameters.bias_term:
kwargs['biased'] = False
if kernel_params.dila_h != 1 or kernel_params.dila_w != 1:
kwargs['dilation'] = (kernel_params.dila_h, kernel_params.dila_w)
assert kernel_params.kernel_h == h
assert kernel_params.kernel_w == w
return MaybeActivated(node)(
'deconv', kernel_params.kernel_h, kernel_params.kernel_w, c_o,
kernel_params.stride_h, kernel_params.stride_w, **kwargs)
def map_relu(self, node):
return PaddleNode('relu')
def map_prelu(self, node):
channel_shared = getattr(node.parameters, 'channel_shared', False)
return PaddleNode('prelu', channel_shared)
def map_tanh(self, node):
return PaddleNode('tanh')
def map_pooling(self, node):
pool_type = node.parameters.pool
if pool_type == 0:
pool_op = 'max_pool'
elif pool_type == 1:
pool_op = 'avg_pool'
else:
# Stochastic pooling, for instance.
raise KaffeError('Unsupported pooling type.')
ceil_mode = getattr(node.layer.parameters, 'ceil_mode', True)
global_pool = getattr(node.layer.parameters, 'global_pooling', False)
if global_pool:
input_shape = node.get_only_parent().output_shape
return PaddleNode(pool_op, input_shape.height, input_shape.width, 1,
1, ceil_mode)
else:
(kernel_params, padding) = self.get_kernel_params(node)
return PaddleNode(pool_op, kernel_params.kernel_h,
kernel_params.kernel_w, kernel_params.stride_h,
kernel_params.stride_w, ceil_mode, **padding)
def map_sigmoid(self, node):
return PaddleNode('sigmoid')
def map_custom(self, node):
from .. import custom_layers
return custom_layers.make_node(PaddleNode, node.kind, node)
def map_inner_product(self, node):
#TODO: Axis
assert node.parameters.axis == 1
#TODO: Unbiased
assert node.parameters.bias_term == True
return MaybeActivated(node)('fc', node.parameters.num_output)
def map_softmax(self, node):
return PaddleNode('softmax', node.parameters.axis)
def map_lrn(self, node):
params = node.parameters
# The window size must be an odd value. For a window
# size of (2*n+1), Paddle defines depth_radius = n.
assert params.local_size % 2 == 1
# Caffe scales by (alpha/(2*n+1)), whereas Paddle
# just scales by alpha (as does Krizhevsky's paper).
# We'll account for that here.
alpha = params.alpha / float(params.local_size)
return PaddleNode('lrn', params.local_size, alpha, params.beta)
def map_concat(self, node):
return PaddleNode('concat', node.parameters.axis)
def map_dropout(self, node):
return PaddleNode('dropout', node.parameters.dropout_ratio)
def map_batch_norm(self, node):
scale_offset = len(node.data) == 4
#this default value comes from caffe's param in batch_norm
default_eps = 1e-5
kwargs = {'scale_offset': scale_offset}
if node.parameters.eps != default_eps:
kwargs['eps'] = node.parameters.eps
return MaybeActivated(
node, default=False)('batch_normalization', **kwargs)
def map_eltwise(self, node):
operations = {0: 'multiply', 1: 'add', 2: 'max'}
op_code = node.parameters.operation
try:
return PaddleNode(operations[op_code])
except KeyError:
raise KaffeError('Unknown elementwise operation: {}'.format(
op_code))
def map_scale(self, node):
params = node.parameters
return PaddleNode('scale', axis=params.axis, num_axes=params.num_axes)
def commit(self, chains):
return chains
class PaddleEmitter(object):
def __init__(self, tab=None):
self.tab = tab or ' ' * 4
self.prefix = ''
self.net_name = ''
def indent(self):
self.prefix += self.tab
def outdent(self):
self.prefix = self.prefix[:-len(self.tab)]
def statement(self, s):
return self.prefix + s + '\n'
def emit_imports(self):
import inspect
codes = []
codes.append(
'### generated by caffe2fluid, your net is in class "%s" ###\n' %
(self.net_name))
network_source = inspect.getsource(network)
codes.append(network_source + '\n')
return self.statement('\n'.join(codes))
def emit_setup_def(self):
return self.statement('def setup(self):')
def get_inputs_info(self, input_nodes):
input_shapes = {}
for n in input_nodes:
name = n.name
output_shape = n.output_shape
shape = [str(s) for s in output_shape[1:]]
input_shapes[name] = ', '.join(shape)
input_shapes = ['"%s": [%s]' % (n, l) for n, l in input_shapes.items()]
shape_str = ','.join(input_shapes)
return '{%s}' % (shape_str)
def emit_main_def(self, name):
if name is None:
return ''
self.prefix = ''
main_def = self.statement('if __name__ == "__main__":')
self.indent()
main_def += self.statement('exit(main())')
return '\n\n' + main_def
def emit_parents(self, chain):
assert len(chain)
s = 'self.feed('
sep = ', \n' + self.prefix + (' ' * len(s))
s += sep.join(
["'%s'" % parent.name for parent in chain[0].node.parents])
return self.statement(s + ')')
def emit_node(self, node):
return self.statement('self.' + node.emit())
def emit(self, name, chains, input_nodes=None):
from ..net_template import generate_net_code
from ..net_template import generate_main_code
self.net_name = name
inputs_info = self.get_inputs_info(input_nodes)
s = self.emit_imports()
s += generate_net_code(name, inputs_info) + '\n'
self.indent()
# define the net using api
s += self.emit_setup_def()
self.indent()
blocks = []
for chain in chains:
b = ''
b += self.emit_parents(chain)
for node in chain:
b += self.emit_node(node)
blocks.append(b[:-1])
s = s + '\n\n'.join(blocks)
# define the main function
s += '\n\n\n' + generate_main_code(name)
s += self.emit_main_def(name)
return s
class Transformer(object):
def __init__(self, def_path, data_path, verbose=True, phase='test'):
self.verbose = verbose
self.phase = phase
self.load(def_path, data_path, phase)
self.params = None
self.source = None
def load(self, def_path, data_path, phase):
# Build the graph
graph = GraphBuilder(def_path, phase).build()
if data_path is not None:
# Load and associate learned parameters
graph = DataInjector(def_path, data_path)(graph)
# Transform the graph
transformers = [
# Fuse split batch normalization layers
BatchNormScaleBiasFuser(),
# Fuse ReLUs
# TODO: Move non-linearity application to layer wrapper, allowing
# any arbitrary operation to be optionally activated.
ReLUFuser(allowed_parent_types=[
NodeKind.Convolution, NodeKind.InnerProduct, NodeKind.BatchNorm
]),
# Rename nodes
# Slashes are used for scoping in Paddle. Replace slashes
# in node names with underscores.
# (Caffe's GoogLeNet implementation uses slashes)
NodeRenamer(lambda node: node.name.replace('/', '_')),
# Fuse Crop
# Crop is to return a scalar output Blob for an input Blob of arbitrary size.
# When one of the input Blob is "input" or "DummyData", we can remove this input Blob
# and put the shape into the reduction layer.
CropFuser()
]
self.graph = graph.transformed(transformers)
#for the purpose of recording name mapping because of fused nodes
trace = SubNodeFuser.traced_names()
chg2real = {}
deleted = {}
for k, v in trace.items():
chg2real[k] = v[-1] #mapping from changed-name to real-name
for n in v:
if n in chg2real:
continue
if n not in deleted:
deleted[n] = '%s.%s' % (k, v[-1])
self.graph.add_name_trace({
'chg2real': chg2real,
'deleted': deleted
}, 'paddle')
# Display the graph
if self.verbose:
print_stderr(self.graph)
def transform_data(self):
if self.params is None:
transformers = [
# Reshape the parameters to Paddle's ordering
DataReshaper({
# (c_o, c_i) -> (c_i, c_o)
NodeKind.InnerProduct: (1, 0)
}),
# Pre-process batch normalization data
BatchNormPreprocessor(),
# Convert parameters to dictionaries
ParameterNamer(),
]
self.graph = self.graph.transformed(transformers)
self.params = {
node.name: node.data
for node in self.graph.nodes if node.data
}
self.params['caffe2fluid_name_trace'] = self.graph.get_name_trace()
return self.params
def transform_source(self):
if self.source is None:
mapper = PaddleMapper(self.graph)
chains = mapper.map()
emitter = PaddleEmitter()
input_nodes = self.graph.get_input_nodes()
self.source = emitter.emit(self.graph.name, chains, input_nodes)
return self.source
"""a util for convert protobuf to dict
"""
from google.protobuf.message import Message
from google.protobuf.descriptor import FieldDescriptor
__all__ = [
"protobuf_to_dict", "TYPE_CALLABLE_MAP", "dict_to_protobuf",
"REVERSE_TYPE_CALLABLE_MAP"
]
EXTENSION_CONTAINER = '___X'
TYPE_CALLABLE_MAP = {
FieldDescriptor.TYPE_DOUBLE: float,
FieldDescriptor.TYPE_FLOAT: float,
FieldDescriptor.TYPE_INT32: int,
FieldDescriptor.TYPE_INT64: long,
FieldDescriptor.TYPE_UINT32: int,
FieldDescriptor.TYPE_UINT64: long,
FieldDescriptor.TYPE_SINT32: int,
FieldDescriptor.TYPE_SINT64: long,
FieldDescriptor.TYPE_FIXED32: int,
FieldDescriptor.TYPE_FIXED64: long,
FieldDescriptor.TYPE_SFIXED32: int,
FieldDescriptor.TYPE_SFIXED64: long,
FieldDescriptor.TYPE_BOOL: bool,
FieldDescriptor.TYPE_STRING: unicode,
FieldDescriptor.TYPE_BYTES: lambda b: b.encode("base64"),
FieldDescriptor.TYPE_ENUM: int,
}
def repeated(type_callable):
return lambda value_list: [type_callable(value) for value in value_list]
def enum_label_name(field, value):
return field.enum_type.values_by_number[int(value)].name
def protobuf_to_dict(pb,
type_callable_map=TYPE_CALLABLE_MAP,
use_enum_labels=False):
result_dict = {}
extensions = {}
for field, value in pb.ListFields():
type_callable = _get_field_value_adaptor(pb, field, type_callable_map,
use_enum_labels)
if field.label == FieldDescriptor.LABEL_REPEATED:
type_callable = repeated(type_callable)
if field.is_extension:
extensions[str(field.number)] = type_callable(value)
continue
result_dict[field.name] = type_callable(value)
if extensions:
result_dict[EXTENSION_CONTAINER] = extensions
return result_dict
def _get_field_value_adaptor(pb,
field,
type_callable_map=TYPE_CALLABLE_MAP,
use_enum_labels=False):
if field.type == FieldDescriptor.TYPE_MESSAGE:
# recursively encode protobuf sub-message
return lambda pb: protobuf_to_dict(pb,
type_callable_map=type_callable_map,
use_enum_labels=use_enum_labels)
if use_enum_labels and field.type == FieldDescriptor.TYPE_ENUM:
return lambda value: enum_label_name(field, value)
if field.type in type_callable_map:
return type_callable_map[field.type]
raise TypeError("Field %s.%s has unrecognised type id %d" %
(pb.__class__.__name__, field.name, field.type))
def get_bytes(value):
return value.decode('base64')
REVERSE_TYPE_CALLABLE_MAP = {FieldDescriptor.TYPE_BYTES: get_bytes, }
def dict_to_protobuf(pb_klass_or_instance,
values,
type_callable_map=REVERSE_TYPE_CALLABLE_MAP,
strict=True):
"""Populates a protobuf model from a dictionary.
:param pb_klass_or_instance: a protobuf message class, or an protobuf instance
:type pb_klass_or_instance: a type or instance of a subclass of google.protobuf.message.Message
:param dict values: a dictionary of values. Repeated and nested values are
fully supported.
:param dict type_callable_map: a mapping of protobuf types to callables for setting
values on the target instance.
:param bool strict: complain if keys in the map are not fields on the message.
"""
if isinstance(pb_klass_or_instance, Message):
instance = pb_klass_or_instance
else:
instance = pb_klass_or_instance()
return _dict_to_protobuf(instance, values, type_callable_map, strict)
def _get_field_mapping(pb, dict_value, strict):
field_mapping = []
for key, value in dict_value.items():
if key == EXTENSION_CONTAINER:
continue
if key not in pb.DESCRIPTOR.fields_by_name:
if strict:
raise KeyError("%s does not have a field called %s" % (pb, key))
continue
field_mapping.append(
(pb.DESCRIPTOR.fields_by_name[key], value, getattr(pb, key, None)))
for ext_num, ext_val in dict_value.get(EXTENSION_CONTAINER, {}).items():
try:
ext_num = int(ext_num)
except ValueError:
raise ValueError("Extension keys must be integers.")
if ext_num not in pb._extensions_by_number:
if strict:
raise KeyError(
"%s does not have a extension with number %s. Perhaps you forgot to import it?"
% (pb, key))
continue
ext_field = pb._extensions_by_number[ext_num]
pb_val = None
pb_val = pb.Extensions[ext_field]
field_mapping.append((ext_field, ext_val, pb_val))
return field_mapping
def _dict_to_protobuf(pb, value, type_callable_map, strict):
fields = _get_field_mapping(pb, value, strict)
for field, input_value, pb_value in fields:
if field.label == FieldDescriptor.LABEL_REPEATED:
for item in input_value:
if field.type == FieldDescriptor.TYPE_MESSAGE:
m = pb_value.add()
_dict_to_protobuf(m, item, type_callable_map, strict)
elif field.type == FieldDescriptor.TYPE_ENUM and isinstance(
item, basestring):
pb_value.append(_string_to_enum(field, item))
else:
pb_value.append(item)
continue
if field.type == FieldDescriptor.TYPE_MESSAGE:
_dict_to_protobuf(pb_value, input_value, type_callable_map, strict)
continue
if field.type in type_callable_map:
input_value = type_callable_map[field.type](input_value)
if field.is_extension:
pb.Extensions[field] = input_value
continue
if field.type == FieldDescriptor.TYPE_ENUM and isinstance(input_value,
basestring):
input_value = _string_to_enum(field, input_value)
setattr(pb, field.name, input_value)
return pb
def _string_to_enum(field, input_value):
enum_dict = field.enum_type.values_by_name
try:
input_value = enum_dict[input_value].number
except KeyError:
raise KeyError("`%s` is not a valid value for field `%s`" %
(input_value, field.name))
return input_value
import math
from collections import namedtuple
from .errors import KaffeError
Tensor4DShape = namedtuple('Tensor4DShape',
['batch_size', 'channels', 'height', 'width'])
Tensor3DShape = namedtuple('Tensor3DShape', ['batch_size', 'data1', 'data2'])
Tensor2DShape = namedtuple('Tensor2DShape', ['batch_size', 'data'])
ScalarShape = namedtuple('ScalarShape', ['batch_size'])
def make_tensor(batch_size, d1=None, d2=None, d3=None):
if d3 is not None:
return Tensor4DShape(batch_size, d1, d2, d3)
elif d1 is not None and d2 is not None:
return Tensor3DShape(batch_size, d1, d2)
elif d1 is not None and d2 is None:
return Tensor2DShape(batch_size, d1)
elif d1 is None and d2 is None and d3 is None:
return ScalarShape(batch_size)
else:
raise NotImplementedError('invalid params for make_tensor %s' \
% (str((batch_size, d1, d2, d3))))
def get_filter_output_shape(i_h, i_w, params, round_func):
dila_h = getattr(params, 'dila_h', 1)
dila_w = getattr(params, 'dila_w', 1)
o_h = (i_h + 2 * params.pad_h -
(dila_h * (params.kernel_h - 1) + 1)) / float(params.stride_h) + 1
o_w = (i_w + 2 * params.pad_w -
(dila_w * (params.kernel_w - 1) + 1)) / float(params.stride_w) + 1
return (int(round_func(o_h)), int(round_func(o_w)))
def get_strided_kernel_output_shape(node, round_func):
assert node.layer is not None
input_shape = node.get_only_parent().output_shape
o_h, o_w = get_filter_output_shape(input_shape.height, input_shape.width,
node.layer.kernel_parameters, round_func)
params = node.layer.parameters
has_c_o = hasattr(params, 'num_output')
c = params.num_output if has_c_o else input_shape.channels
return make_tensor(input_shape.batch_size, c, o_h, o_w)
def shape_not_implemented(node):
raise NotImplementedError
def shape_identity(node):
assert len(node.parents) > 0
return node.parents[0].output_shape
def shape_scalar(node):
return make_tensor(1, 1, 1, 1)
def shape_crop(node):
raise KaffeError('crop function had been defined in customer_layers')
def shape_power(node):
raise KaffeError('power function had been defined in customer_layers')
def shape_data(node):
if node.output_shape:
# Old-style input specification
shape = node.output_shape
else:
try:
# New-style input specification
shape = map(int, node.parameters.shape[0].dim)
except:
# We most likely have a data layer on our hands. The problem is,
# Caffe infers the dimensions of the data from the source (eg: LMDB).
# We want to avoid reading datasets here. Fail for now.
# This can be temporarily fixed by transforming the data layer to
# Caffe's "input" layer (as is usually used in the "deploy" version).
# TODO: Find a better solution for this.
raise KaffeError(
'Cannot determine dimensions of data layer.\n'
'See comments in function shape_data for more info.')
return shape
def shape_mem_data(node):
params = node.parameters
return make_tensor(params.batch_size, params.channels, params.height,
params.width)
def shape_concat(node):
axis = node.layer.parameters.axis
output_shape = None
for parent in node.parents:
if output_shape is None:
output_shape = list(parent.output_shape)
else:
output_shape[axis] += parent.output_shape[axis]
return tuple(output_shape)
def shape_convolution(node):
return get_strided_kernel_output_shape(node, math.floor)
def shape_deconvolution(node):
assert node.layer is not None
input_shape = node.get_only_parent().output_shape
h_i = input_shape.height
w_i = input_shape.width
params = node.layer.kernel_parameters
p_h = params.pad_h
p_w = params.pad_w
dila_h = params.dila_h
dila_w = params.dila_w
k_h = params.kernel_h
k_w = params.kernel_w
s_h = params.stride_h
s_w = params.stride_w
h_o = (h_i - 1) * s_h - 2 * p_h + dila_h * (k_h - 1) + 1
w_o = (w_i - 1) * s_w - 2 * p_w + dila_w * (k_w - 1) + 1
params = node.layer.parameters
has_c_o = hasattr(params, 'num_output')
c = params.num_output if has_c_o else input_shape.channels
return make_tensor(input_shape.batch_size, c, h_o, w_o)
def shape_pool(node):
global_pool = getattr(node.layer.parameters, 'global_pooling', False)
if global_pool:
input_shape = node.get_only_parent().output_shape
return make_tensor(input_shape.batch_size, input_shape.channels, 1, 1)
ceil_mode = getattr(node.layer.parameters, 'ceil_mode', True)
if ceil_mode is True:
method = math.ceil
else:
method = math.floor
return get_strided_kernel_output_shape(node, method)
def shape_inner_product(node):
input_shape = node.get_only_parent().output_shape
return make_tensor(input_shape.batch_size, node.layer.parameters.num_output)
'''
A collection of graph transforms.
A transformer is a callable that accepts a graph and returns a transformed version.
'''
import os
import numpy as np
from .caffe import get_caffe_resolver, has_pycaffe
from .errors import KaffeError, debug, notice, warn
from .layers import NodeKind
class DataInjector(object):
'''
Associates parameters loaded from a .caffemodel file with their corresponding nodes.
'''
def __init__(self, def_path, data_path):
# The .prototxt file defining the graph
self.def_path = def_path
# The .caffemodel file containing the learned parameters
self.data_path = data_path
# Set to true if the fallback protocol-buffer based backend was used
self.did_use_pb = False
# A list containing (layer name, parameters) tuples
self.params = None
# Load the parameters
self.load()
def load(self):
if has_pycaffe():
self.load_using_caffe()
else:
self.load_using_pb()
def load_using_caffe(self):
caffe = get_caffe_resolver().caffe
net = caffe.Net(self.def_path, self.data_path, caffe.TEST)
data = lambda blob: blob.data
self.params = [(k, map(data, v)) for k, v in net.params.items()]
def load_using_pb(self):
data = get_caffe_resolver().NetParameter()
data.MergeFromString(open(self.data_path, 'rb').read())
pair = lambda layer: (layer.name, self.normalize_pb_data(layer))
layers = data.layers or data.layer
self.params = [pair(layer) for layer in layers if layer.blobs]
self.did_use_pb = True
def normalize_pb_data(self, layer):
transformed = []
for blob in layer.blobs:
if len(blob.shape.dim):
dims = blob.shape.dim
c_o, c_i, h, w = map(int, [1] * (4 - len(dims)) + list(dims))
else:
c_o = blob.num
c_i = blob.channels
h = blob.height
w = blob.width
data = np.array(blob.data, dtype=np.float32).reshape(c_o, c_i, h, w)
transformed.append(data)
return transformed
def adjust_parameters(self, node, data):
if not self.did_use_pb:
return data
# When using the protobuf-backend, each parameter initially has four dimensions.
# In certain cases (like FC layers), we want to eliminate the singleton dimensions.
# This implementation takes care of the common cases. However, it does leave the
# potential for future issues.
# The Caffe-backend does not suffer from this problem.
data = list(data)
squeeze_indices = [1] # Squeeze biases.
if node.kind == NodeKind.InnerProduct:
squeeze_indices.append(0) # Squeeze FC.
for idx in squeeze_indices:
if idx >= len(data):
continue
d = data[idx]
assert len(
d.shape
) == 4, 'invalid shape[%s] from caffe when adjust_parameters' % (
str(d.shape))
shape_old = d.shape
sq_axis = None
if idx == 0:
sq_axis = (0, 1)
elif idx == 1:
sq_axis = (0, 1, 2)
else:
continue
data[idx] = np.squeeze(d, axis=sq_axis)
shape_new = data[idx].shape
if len(shape_old) != shape_new:
debug('squeeze idx:%d, with kind:%s,name:%s' % \
(idx, node.kind, node.name))
return data
def __call__(self, graph):
for layer_name, data in self.params:
if layer_name in graph:
node = graph.get_node(layer_name)
node.data = self.adjust_parameters(node, data)
else:
notice('Ignoring parameters for non-existent layer: %s' % \
layer_name)
return graph
class DataReshaper(object):
def __init__(self, mapping, replace=True):
# A dictionary mapping NodeKind to the transposed order.
self.mapping = mapping
# The node kinds eligible for reshaping
self.reshaped_node_types = self.mapping.keys()
# If true, the reshaped data will replace the old one.
# Otherwise, it's set to the reshaped_data attribute.
self.replace = replace
def has_spatial_parent(self, node):
try:
parent = node.get_only_parent()
s = parent.output_shape
if len(s) == 4:
return s.height > 1 or s.width > 1
else:
return False
except KaffeError:
return False
def map(self, node_kind):
try:
return self.mapping[node_kind]
except KeyError:
raise KaffeError('Ordering not found for node kind: {}'.format(
node_kind))
def __call__(self, graph):
for node in graph.nodes:
if node.data is None:
continue
if node.kind not in self.reshaped_node_types:
# Check for 2+ dimensional data
#if any(len(tensor.shape) > 1 for tensor in node.data):
# notice('parmaters not reshaped for node: {}'.format(node))
continue
transpose_order = self.map(node.kind)
weights = node.data[0]
if node.kind == NodeKind.InnerProduct:
# The FC layer connected to the spatial layer needs to be
# re-wired to match the new spatial ordering.
#in_shape = node.get_only_parent().output_shape
fc_shape = weights.shape
output_channels = fc_shape[0]
weights = weights.reshape((output_channels, -1))
weights = weights.transpose(transpose_order)
node.reshaped_data = weights
else:
node.reshaped_data = weights.transpose(transpose_order)
if self.replace:
for node in graph.nodes:
if hasattr(node, 'reshaped_data'):
# Set the weights
node.data[0] = node.reshaped_data
del node.reshaped_data
return graph
class CropFuser(object):
'''
Crop is to return a scalar output Blob for an input Blob of arbitrary size.
When one of the input Blob is "input" or "DummyData", we can remove the input Blob
and put the shape into the reduction layer.
'''
_traced_names = {}
@classmethod
def traced_names(cls):
return cls._traced_names
@classmethod
def trace(cls, fname, tname):
""" recording the names mapping,
the value of 'fname' will be replaced by value of 'tname'
"""
if fname not in cls._traced_names:
cls._traced_names[fname] = []
cls._traced_names[fname].append(tname)
def __init__(self,
allowed_parent_types=[NodeKind.Input, NodeKind.DummyData]):
self.allowed_parent_types = allowed_parent_types
def __call__(self, graph):
nodes = graph.nodes
fused_nodes = []
for node in nodes:
if len(node.parents) != 2:
# reduction layer must has two parent layers.
continue
parent = node.parents[1]
if not self.is_eligible_pair(parent, node):
continue
# Change the graph structure.
parent.children.remove(node)
node.parents.remove(parent)
# Let the sub-class merge the fused node in any arbitrary way.
if not len(parent.children):
fused_nodes.append(parent)
#fused_nodes.append(parent)
self.merge(parent, node)
# rebuild the graph
transformed_nodes = [node for node in nodes if node not in fused_nodes]
return graph.replaced(transformed_nodes)
def is_eligible_pair(self, parent, child):
'''Returns true if this parent/child pair is eligible for fusion.'''
return child.kind == NodeKind.Crop
#return (self.allowed_parent_types is not None and \
# len(parent.children) == 1 and \
# parent.kind in self.allowed_parent_types and \
# child.kind == NodeKind.Crop)
def merge(self, parent, child):
'''Merge the parent node into the child.'''
child.metadata['shape'] = [
parent.output_shape.batch_size, parent.output_shape.channels,
parent.output_shape.height, parent.output_shape.width
]
class SubNodeFuser(object):
'''
An abstract helper for merging a single-child with its single-parent.
'''
_traced_names = {}
@classmethod
def traced_names(cls):
return cls._traced_names
@classmethod
def trace(cls, fname, tname):
""" recording the names mapping,
the value of 'fname' will be replaced by value of 'tname'
"""
if fname not in cls._traced_names:
cls._traced_names[fname] = []
cls._traced_names[fname].append(tname)
def __call__(self, graph):
nodes = graph.nodes
fused_nodes = []
for node in nodes:
if len(node.parents) != 1:
# We're only fusing nodes with single parents
continue
parent = node.get_only_parent()
if len(parent.children) != 1:
# We can only fuse a node if its parent's
# value isn't used by any other node.
continue
if not self.is_eligible_pair(parent, node):
continue
# Rewrite the fused node's children to its parent.
for child in node.children:
pos = child.parents.index(node)
child.parents[pos] = parent
parent.add_child(child)
# Disconnect the fused node from the graph.
parent.children.remove(node)
fused_nodes.append(node)
# Let the sub-class merge the fused node in any arbitrary way.
self.merge(parent, node)
transformed_nodes = [node for node in nodes if node not in fused_nodes]
return graph.replaced(transformed_nodes)
def is_eligible_pair(self, parent, child):
'''Returns true if this parent/child pair is eligible for fusion.'''
raise NotImplementedError('Must be implemented by subclass.')
def merge(self, parent, child):
'''Merge the child node into the parent.'''
raise NotImplementedError('Must be implemented by subclass')
class ReLUFuser(SubNodeFuser):
'''
Fuses rectified linear units with their parent nodes.
'''
def __init__(self, allowed_parent_types=None):
# Fuse ReLUs when the parent node is one of the given types.
# If None, all node types are eligible.
self.allowed_parent_types = allowed_parent_types
def is_eligible_pair(self, parent, child):
return ((self.allowed_parent_types is None or \
parent.kind in self.allowed_parent_types) and \
child.kind == NodeKind.ReLU)
def merge(self, parent, child):
SubNodeFuser.trace(parent.name, child.name)
parent.metadata['relu'] = True
parent.metadata['relu_negative_slope'] = child.parameters.negative_slope
class BatchNormScaleBiasFuser(SubNodeFuser):
'''
The original batch normalization paper includes two learned
parameters: a scaling factor \gamma and a bias \beta.
Caffe's implementation does not include these two. However, it is commonly
replicated by adding a scaling+bias layer immidiately after the batch norm.
This fuser merges the scaling+bias layer with the batch norm.
'''
def is_eligible_pair(self, parent, child):
return (parent.kind == NodeKind.BatchNorm and \
child.kind == NodeKind.Scale and \
child.parameters.axis == 1 and \
child.parameters.bias_term == True)
def merge(self, parent, child):
SubNodeFuser.trace(parent.name, child.name)
parent.scale_bias_node = child
class BatchNormPreprocessor(object):
'''
Prescale batch normalization parameters.
Concatenate gamma (scale) and beta (bias) terms if set.
'''
def __call__(self, graph):
for node in graph.nodes:
if node.kind != NodeKind.BatchNorm:
continue
assert node.data is not None
assert len(node.data) == 3
node.data = [np.squeeze(i) for i in node.data]
mean, variance, scale = node.data
# Prescale the stats
scaling_factor = 1.0 / scale if scale != 0 else 0
mean *= scaling_factor
variance *= scaling_factor
# Replace with the updated values
node.data = [mean, variance]
if hasattr(node, 'scale_bias_node'):
# Include the scale and bias terms
gamma, beta = node.scale_bias_node.data
node.data += [np.squeeze(i) for i in [gamma, beta]]
return graph
class NodeRenamer(object):
'''
Renames nodes in the graph using a given unary function that
accepts a node and returns its new name.
'''
def __init__(self, renamer):
self.renamer = renamer
def __call__(self, graph):
for node in graph.nodes:
node.name = self.renamer(node)
return graph
class ParameterNamer(object):
'''
Convert layer data arrays to a dictionary mapping parameter names to their values.
'''
def __call__(self, graph):
for node in graph.nodes:
if node.data is None:
continue
if node.kind in (NodeKind.Convolution, NodeKind.InnerProduct,\
NodeKind.Deconvolution):
names = ('weights', )
if node.parameters.bias_term:
names += ('biases', )
elif node.kind == NodeKind.BatchNorm:
names = ('mean', 'variance')
if len(node.data) == 4:
names += ('scale', 'offset')
elif node.kind == NodeKind.Scale:
names = ('scale', )
if getattr(node.parameters, 'bias_term', False):
names = ('scale', 'offset')
elif node.kind == NodeKind.PReLU:
names = ('negslope', )
elif node.kind == "Normalize":
names = ('scale', )
else:
warn('Unhandled parameters when naming this it[%s]' %
(node.kind))
continue
assert len(names) == len(node.data)
node.data = dict(zip(names, node.data))
return graph
syntax = "proto2";
package caffe;
// Specifies the shape (dimensions) of a Blob.
message BlobShape { repeated int64 dim = 1 [ packed = true ]; }
message BlobProto {
optional BlobShape shape = 7;
repeated float data = 5 [ packed = true ];
repeated float diff = 6 [ packed = true ];
repeated double double_data = 8 [ packed = true ];
repeated double double_diff = 9 [ packed = true ];
// 4D dimensions -- deprecated. Use "shape" instead.
optional int32 num = 1 [ default = 0 ];
optional int32 channels = 2 [ default = 0 ];
optional int32 height = 3 [ default = 0 ];
optional int32 width = 4 [ default = 0 ];
}
// The BlobProtoVector is simply a way to pass multiple blobproto instances
// around.
message BlobProtoVector { repeated BlobProto blobs = 1; }
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
// the actual image data, in bytes
optional bytes data = 4;
optional int32 label = 5;
// Optionally, the datum could also hold float data.
repeated float float_data = 6;
// If true data contains an encoded image that need to be decoded
optional bool encoded = 7 [ default = false ];
}
message FillerParameter {
// The filler type.
optional string type = 1 [ default = 'constant' ];
optional float value = 2 [ default = 0 ]; // the value in constant filler
optional float min = 3 [ default = 0 ]; // the min value in uniform filler
optional float max = 4 [ default = 1 ]; // the max value in uniform filler
optional float mean = 5 [ default = 0 ]; // the mean value in Gaussian filler
optional float std = 6 [ default = 1 ]; // the std value in Gaussian filler
// The expected number of non-zero output weights for a given input in
// Gaussian filler -- the default -1 means don't perform sparsification.
optional int32 sparse = 7 [ default = -1 ];
// Normalize the filler variance by fan_in, fan_out, or their average.
// Applies to 'xavier' and 'msra' fillers.
enum VarianceNorm {
FAN_IN = 0;
FAN_OUT = 1;
AVERAGE = 2;
}
optional VarianceNorm variance_norm = 8 [ default = FAN_IN ];
}
message NetParameter {
optional string name = 1; // consider giving the network a name
// DEPRECATED. See InputParameter. The input blobs to the network.
repeated string input = 3;
// DEPRECATED. See InputParameter. The shape of the input blobs.
repeated BlobShape input_shape = 8;
// 4D input dimensions -- deprecated. Use "input_shape" instead.
// If specified, for each input blob there should be four
// values specifying the num, channels, height and width of the input blob.
// Thus, there should be a total of (4 * #input) numbers.
repeated int32 input_dim = 4;
// Whether the network will force every layer to carry out backward operation.
// If set False, then whether to carry out backward is determined
// automatically according to the net structure and learning rates.
optional bool force_backward = 5 [ default = false ];
// The current "state" of the network, including the phase, level, and stage.
// Some layers may be included/excluded depending on this state and the states
// specified in the layers' include and exclude fields.
optional NetState state = 6;
// Print debugging information about results while running Net::Forward,
// Net::Backward, and Net::Update.
optional bool debug_info = 7 [ default = false ];
// The layers that make up the net. Each of their configurations, including
// connectivity and behavior, is specified as a LayerParameter.
repeated LayerParameter layer = 100; // ID 100 so layers are printed last.
// DEPRECATED: use 'layer' instead.
repeated V1LayerParameter layers = 2;
}
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 42 (last added: layer_wise_reduce)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
//
// Exactly one train net must be specified using one of the following fields:
// train_net_param, train_net, net_param, net
// One or more test nets may be specified using any of the following fields:
// test_net_param, test_net, net_param, net
// If more than one test net field is specified (e.g., both net and
// test_net are specified), they will be evaluated in the field order given
// above: (1) test_net_param, (2) test_net, (3) net_param/net.
// A test_iter must be specified for each test_net.
// A test_level and/or a test_stage may also be specified for each test_net.
//////////////////////////////////////////////////////////////////////////////
// Proto filename for the train net, possibly combined with one or more
// test nets.
optional string net = 24;
// Inline train net param, possibly combined with one or more test nets.
optional NetParameter net_param = 25;
optional string train_net = 1; // Proto filename for the train net.
repeated string test_net = 2; // Proto filenames for the test nets.
optional NetParameter train_net_param = 21; // Inline train net params.
repeated NetParameter test_net_param = 22; // Inline test net params.
// The states for the train/test nets. Must be unspecified or
// specified once per net.
//
// By default, train_state will have phase = TRAIN,
// and all test_state's will have phase = TEST.
// Other defaults are set according to the NetState defaults.
optional NetState train_state = 26;
repeated NetState test_state = 27;
// The number of iterations for each test net.
repeated int32 test_iter = 3;
// The number of iterations between two testing phases.
optional int32 test_interval = 4 [ default = 0 ];
optional bool test_compute_loss = 19 [ default = false ];
// If true, run an initial test pass before the first iteration,
// ensuring memory availability and printing the starting value of the loss.
optional bool test_initialization = 32 [ default = true ];
optional float base_lr = 5; // The base learning rate
// the number of iterations between displaying info. If display = 0, no info
// will be displayed.
optional int32 display = 6;
// Display the loss averaged over the last average_loss iterations
optional int32 average_loss = 33 [ default = 1 ];
optional int32 max_iter = 7; // the maximum number of iterations
// accumulate gradients over `iter_size` x `batch_size` instances
optional int32 iter_size = 36 [ default = 1 ];
// The learning rate decay policy. The currently implemented learning rate
// policies are as follows:
// - fixed: always return base_lr.
// - step: return base_lr * gamma ^ (floor(iter / step))
// - exp: return base_lr * gamma ^ iter
// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
// - multistep: similar to step but it allows non uniform steps defined by
// stepvalue
// - poly: the effective learning rate follows a polynomial decay, to be
// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
// - sigmoid: the effective learning rate follows a sigmod decay
// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
optional string lr_policy = 8;
optional float gamma = 9; // The parameter to compute the learning rate.
optional float power = 10; // The parameter to compute the learning rate.
optional float momentum = 11; // The momentum value.
optional float weight_decay = 12; // The weight decay.
// regularization types supported: L1 and L2
// controlled by weight_decay
optional string regularization_type = 29 [ default = "L2" ];
// the stepsize for learning rate policy "step"
optional int32 stepsize = 13;
// the stepsize for learning rate policy "multistep"
repeated int32 stepvalue = 34;
// Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,
// whenever their actual L2 norm is larger.
optional float clip_gradients = 35 [ default = -1 ];
optional int32 snapshot = 14 [ default = 0 ]; // The snapshot interval
optional string snapshot_prefix = 15; // The prefix for the snapshot.
// whether to snapshot diff in the results or not. Snapshotting diff will help
// debugging but the final protocol buffer size will be much larger.
optional bool snapshot_diff = 16 [ default = false ];
enum SnapshotFormat {
HDF5 = 0;
BINARYPROTO = 1;
}
optional SnapshotFormat snapshot_format = 37 [ default = BINARYPROTO ];
// the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default.
enum SolverMode {
CPU = 0;
GPU = 1;
}
optional SolverMode solver_mode = 17 [ default = GPU ];
// the device_id will that be used in GPU mode. Use device_id = 0 in default.
optional int32 device_id = 18 [ default = 0 ];
// If non-negative, the seed with which the Solver will initialize the Caffe
// random number generator -- useful for reproducible results. Otherwise,
// (and by default) initialize using a seed derived from the system clock.
optional int64 random_seed = 20 [ default = -1 ];
// type of the solver
optional string type = 40 [ default = "SGD" ];
// numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
optional float delta = 31 [ default = 1e-8 ];
// parameters for the Adam solver
optional float momentum2 = 39 [ default = 0.999 ];
// RMSProp decay value
// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
optional float rms_decay = 38 [ default = 0.99 ];
// If true, print information about the state of the net that may help with
// debugging learning problems.
optional bool debug_info = 23 [ default = false ];
// If false, don't save a snapshot after training finishes.
optional bool snapshot_after_train = 28 [ default = true ];
// DEPRECATED: old solver enum types, use string instead
enum SolverType {
SGD = 0;
NESTEROV = 1;
ADAGRAD = 2;
RMSPROP = 3;
ADADELTA = 4;
ADAM = 5;
}
// DEPRECATED: use type instead of solver_type
optional SolverType solver_type = 30 [ default = SGD ];
// Overlap compute and communication for data parallel training
optional bool layer_wise_reduce = 41 [ default = true ];
}
// A message that stores the solver snapshots
message SolverState {
optional int32 iter = 1; // The current iteration
optional string learned_net = 2; // The file that stores the learned net.
repeated BlobProto history = 3; // The history for sgd solvers
optional int32 current_step = 4
[ default = 0 ]; // The current step for learning rate
}
enum Phase {
TRAIN = 0;
TEST = 1;
}
message NetState {
optional Phase phase = 1 [ default = TEST ];
optional int32 level = 2 [ default = 0 ];
repeated string stage = 3;
}
message NetStateRule {
// Set phase to require the NetState have a particular phase (TRAIN or TEST)
// to meet this rule.
optional Phase phase = 1;
// Set the minimum and/or maximum levels in which the layer should be used.
// Leave undefined to meet the rule regardless of level.
optional int32 min_level = 2;
optional int32 max_level = 3;
// Customizable sets of stages to include or exclude.
// The net must have ALL of the specified stages and NONE of the specified
// "not_stage"s to meet the rule.
// (Use multiple NetStateRules to specify conjunctions of stages.)
repeated string stage = 4;
repeated string not_stage = 5;
}
// Specifies training parameters (multipliers on global learning constants,
// and the name and other settings used for weight sharing).
message ParamSpec {
// The names of the parameter blobs -- useful for sharing parameters among
// layers, but never required otherwise. To share a parameter between two
// layers, give it a (non-empty) name.
optional string name = 1;
// Whether to require shared weights to have the same shape, or just the same
// count -- defaults to STRICT if unspecified.
optional DimCheckMode share_mode = 2;
enum DimCheckMode {
// STRICT (default) requires that num, channels, height, width each match.
STRICT = 0;
// PERMISSIVE requires only the count (num*channels*height*width) to match.
PERMISSIVE = 1;
}
// The multiplier on the global learning rate for this parameter.
optional float lr_mult = 3 [ default = 1.0 ];
// The multiplier on the global weight decay for this parameter.
optional float decay_mult = 4 [ default = 1.0 ];
}
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available layer-specific ID: 147 (last added:
// recurrent_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
repeated string bottom = 3; // the name of each bottom blob
repeated string top = 4; // the name of each top blob
// The train / test phase for computation.
optional Phase phase = 10;
// The amount of weight to assign each top blob in the objective.
// Each layer assigns a default value, usually of either 0 or 1,
// to each top blob.
repeated float loss_weight = 5;
// Specifies training parameters (multipliers on global learning constants,
// and the name and other settings used for weight sharing).
repeated ParamSpec param = 6;
// The blobs containing the numeric parameters of the layer.
repeated BlobProto blobs = 7;
// Specifies whether to backpropagate to each bottom. If unspecified,
// Caffe will automatically infer whether each input needs backpropagation
// to compute parameter gradients. If set to true for some inputs,
// backpropagation to those inputs is forced; if set false for some inputs,
// backpropagation to those inputs is skipped.
//
// The size must be either 0 or equal to the number of bottoms.
repeated bool propagate_down = 11;
// Rules controlling whether and when a layer is included in the network,
// based on the current NetState. You may specify a non-zero number of rules
// to include OR exclude, but not both. If no include or exclude rules are
// specified, the layer is always included. If the current NetState meets
// ANY (i.e., one or more) of the specified rules, the layer is
// included/excluded.
repeated NetStateRule include = 8;
repeated NetStateRule exclude = 9;
// Parameters for data pre-processing.
optional TransformationParameter transform_param = 100;
// Parameters shared by loss layers.
optional LossParameter loss_param = 101;
// Layer type-specific parameters.
//
// Note: certain layers may have more than one computational engine
// for their implementation. These layers include an Engine type and
// engine parameter for selecting the implementation.
// The default for the engine is set by the ENGINE switch at compile-time.
optional AccuracyParameter accuracy_param = 102;
optional ArgMaxParameter argmax_param = 103;
optional BatchNormParameter batch_norm_param = 139;
optional BiasParameter bias_param = 141;
optional ConcatParameter concat_param = 104;
optional ContrastiveLossParameter contrastive_loss_param = 105;
optional ConvolutionParameter convolution_param = 106;
optional CropParameter crop_param = 144;
optional DataParameter data_param = 107;
optional DropoutParameter dropout_param = 108;
optional DummyDataParameter dummy_data_param = 109;
optional EltwiseParameter eltwise_param = 110;
optional ELUParameter elu_param = 140;
optional EmbedParameter embed_param = 137;
optional ExpParameter exp_param = 111;
optional FlattenParameter flatten_param = 135;
optional HDF5DataParameter hdf5_data_param = 112;
optional HDF5OutputParameter hdf5_output_param = 113;
optional HingeLossParameter hinge_loss_param = 114;
optional ImageDataParameter image_data_param = 115;
optional InfogainLossParameter infogain_loss_param = 116;
optional InnerProductParameter inner_product_param = 117;
optional InputParameter input_param = 143;
optional LogParameter log_param = 134;
optional LRNParameter lrn_param = 118;
optional MemoryDataParameter memory_data_param = 119;
optional MVNParameter mvn_param = 120;
optional ParameterParameter parameter_param = 145;
optional PoolingParameter pooling_param = 121;
optional PowerParameter power_param = 122;
optional PReLUParameter prelu_param = 131;
optional PythonParameter python_param = 130;
optional RecurrentParameter recurrent_param = 146;
optional ReductionParameter reduction_param = 136;
optional ReLUParameter relu_param = 123;
optional ReshapeParameter reshape_param = 133;
optional ScaleParameter scale_param = 142;
optional SigmoidParameter sigmoid_param = 124;
optional SoftmaxParameter softmax_param = 125;
optional SPPParameter spp_param = 132;
optional SliceParameter slice_param = 126;
optional TanHParameter tanh_param = 127;
optional ThresholdParameter threshold_param = 128;
optional TileParameter tile_param = 138;
optional WindowDataParameter window_data_param = 129;
}
// Message that stores parameters used to apply transformation
// to the data layer's data
message TransformationParameter {
// For data pre-processing, we can do simple scaling and subtracting the
// data mean, if provided. Note that the mean subtraction is always carried
// out before scaling.
optional float scale = 1 [ default = 1 ];
// Specify if we want to randomly mirror data.
optional bool mirror = 2 [ default = false ];
// Specify if we would like to randomly crop an image.
optional uint32 crop_size = 3 [ default = 0 ];
// mean_file and mean_value cannot be specified at the same time
optional string mean_file = 4;
// if specified can be repeated once (would subtract it from all the channels)
// or can be repeated the same number of times as channels
// (would subtract them from the corresponding channel)
repeated float mean_value = 5;
// Force the decoded image to have 3 color channels.
optional bool force_color = 6 [ default = false ];
// Force the decoded image to have 1 color channels.
optional bool force_gray = 7 [ default = false ];
}
// Message that stores parameters shared by loss layers
message LossParameter {
// If specified, ignore instances with the given label.
optional int32 ignore_label = 1;
// How to normalize the loss for loss layers that aggregate across batches,
// spatial dimensions, or other dimensions. Currently only implemented in
// SoftmaxWithLoss and SigmoidCrossEntropyLoss layers.
enum NormalizationMode {
// Divide by the number of examples in the batch times spatial dimensions.
// Outputs that receive the ignore label will NOT be ignored in computing
// the normalization factor.
FULL = 0;
// Divide by the total number of output locations that do not take the
// ignore_label. If ignore_label is not set, this behaves like FULL.
VALID = 1;
// Divide by the batch size.
BATCH_SIZE = 2;
// Do not normalize the loss.
NONE = 3;
}
// For historical reasons, the default normalization for
// SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID.
optional NormalizationMode normalization = 3 [ default = VALID ];
// Deprecated. Ignored if normalization is specified. If normalization
// is not specified, then setting this to false will be equivalent to
// normalization = BATCH_SIZE to be consistent with previous behavior.
optional bool normalize = 2;
}
// Messages that store parameters used by individual layer types follow, in
// alphabetical order.
message AccuracyParameter {
// When computing accuracy, count as correct by comparing the true label to
// the top k scoring classes. By default, only compare to the top scoring
// class (i.e. argmax).
optional uint32 top_k = 1 [ default = 1 ];
// The "label" axis of the prediction blob, whose argmax corresponds to the
// predicted label -- may be negative to index from the end (e.g., -1 for the
// last axis). For example, if axis == 1 and the predictions are
// (N x C x H x W), the label blob is expected to contain N*H*W ground truth
// labels with integer values in {0, 1, ..., C-1}.
optional int32 axis = 2 [ default = 1 ];
// If specified, ignore instances with the given label.
optional int32 ignore_label = 3;
}
message ArgMaxParameter {
// If true produce pairs (argmax, maxval)
optional bool out_max_val = 1 [ default = false ];
optional uint32 top_k = 2 [ default = 1 ];
// The axis along which to maximise -- may be negative to index from the
// end (e.g., -1 for the last axis).
// By default ArgMaxLayer maximizes over the flattened trailing dimensions
// for each index of the first / num dimension.
optional int32 axis = 3;
}
message ConcatParameter {
// The axis along which to concatenate -- may be negative to index from the
// end (e.g., -1 for the last axis). Other axes must have the
// same dimension for all the bottom blobs.
// By default, ConcatLayer concatenates blobs along the "channels" axis (1).
optional int32 axis = 2 [ default = 1 ];
// DEPRECATED: alias for "axis" -- does not support negative indexing.
optional uint32 concat_dim = 1 [ default = 1 ];
}
message BatchNormParameter {
// If false, normalization is performed over the current mini-batch
// and global statistics are accumulated (but not yet used) by a moving
// average.
// If true, those accumulated mean and variance values are used for the
// normalization.
// By default, it is set to false when the network is in the training
// phase and true when the network is in the testing phase.
optional bool use_global_stats = 1;
// What fraction of the moving average remains each iteration?
// Smaller values make the moving average decay faster, giving more
// weight to the recent values.
// Each iteration updates the moving average @f$S_{t-1}@f$ with the
// current mean @f$ Y_t @f$ by
// @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$
// is the moving_average_fraction parameter.
optional float moving_average_fraction = 2 [ default = .999 ];
// Small value to add to the variance estimate so that we don't divide by
// zero.
optional float eps = 3 [ default = 1e-5 ];
}
message BiasParameter {
// The first axis of bottom[0] (the first input Blob) along which to apply
// bottom[1] (the second input Blob). May be negative to index from the end
// (e.g., -1 for the last axis).
//
// For example, if bottom[0] is 4D with shape 100x3x40x60, the output
// top[0] will have the same shape, and bottom[1] may have any of the
// following shapes (for the given value of axis):
// (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60
// (axis == 1 == -3) 3; 3x40; 3x40x60
// (axis == 2 == -2) 40; 40x60
// (axis == 3 == -1) 60
// Furthermore, bottom[1] may have the empty shape (regardless of the value of
// "axis") -- a scalar bias.
optional int32 axis = 1 [ default = 1 ];
// (num_axes is ignored unless just one bottom is given and the bias is
// a learned parameter of the layer. Otherwise, num_axes is determined by the
// number of axes by the second bottom.)
// The number of axes of the input (bottom[0]) covered by the bias
// parameter, or -1 to cover all axes of bottom[0] starting from `axis`.
// Set num_axes := 0, to add a zero-axis Blob: a scalar.
optional int32 num_axes = 2 [ default = 1 ];
// (filler is ignored unless just one bottom is given and the bias is
// a learned parameter of the layer.)
// The initialization for the learned bias parameter.
// Default is the zero (0) initialization, resulting in the BiasLayer
// initially performing the identity operation.
optional FillerParameter filler = 3;
}
message ContrastiveLossParameter {
// margin for dissimilar pair
optional float margin = 1 [ default = 1.0 ];
// The first implementation of this cost did not exactly match the cost of
// Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2.
// legacy_version = false (the default) uses (margin - d)^2 as proposed in the
// Hadsell paper. New models should probably use this version.
// legacy_version = true uses (margin - d^2). This is kept to support /
// reproduce existing models and results
optional bool legacy_version = 2 [ default = false ];
}
message ConvolutionParameter {
optional uint32 num_output = 1; // The number of outputs for the layer
optional bool bias_term = 2 [ default = true ]; // whether to have bias terms
// Pad, kernel size, and stride are all given as a single value for equal
// dimensions in all spatial dimensions, or once per spatial dimension.
repeated uint32 pad = 3; // The padding size; defaults to 0
repeated uint32 kernel_size = 4; // The kernel size
repeated uint32 stride = 6; // The stride; defaults to 1
// Factor used to dilate the kernel, (implicitly) zero-filling the resulting
// holes. (Kernel dilation is sometimes referred to by its use in the
// algorithme à trous from Holschneider et al. 1987.)
repeated uint32 dilation = 18; // The dilation; defaults to 1
// For 2D convolution only, the *_h and *_w versions may also be used to
// specify both spatial dimensions.
optional uint32 pad_h = 9 [ default = 0 ]; // The padding height (2D only)
optional uint32 pad_w = 10 [ default = 0 ]; // The padding width (2D only)
optional uint32 kernel_h = 11; // The kernel height (2D only)
optional uint32 kernel_w = 12; // The kernel width (2D only)
optional uint32 stride_h = 13; // The stride height (2D only)
optional uint32 stride_w = 14; // The stride width (2D only)
optional uint32 group = 5 [ default = 1 ]; // The group size for group conv
optional FillerParameter weight_filler = 7; // The filler for the weight
optional FillerParameter bias_filler = 8; // The filler for the bias
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 15 [ default = DEFAULT ];
// The axis to interpret as "channels" when performing convolution.
// Preceding dimensions are treated as independent inputs;
// succeeding dimensions are treated as "spatial".
// With (N, C, H, W) inputs, and axis == 1 (the default), we perform
// N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for
// groups g>1) filters across the spatial axes (H, W) of the input.
// With (N, C, D, H, W) inputs, and axis == 1, we perform
// N independent 3D convolutions, sliding (C/g)-channels
// filters across the spatial axes (D, H, W) of the input.
optional int32 axis = 16 [ default = 1 ];
// Whether to force use of the general ND convolution, even if a specific
// implementation for blobs of the appropriate number of spatial dimensions
// is available. (Currently, there is only a 2D-specific convolution
// implementation; for input blobs with num_axes != 2, this option is
// ignored and the ND implementation will be used.)
optional bool force_nd_im2col = 17 [ default = false ];
}
message CropParameter {
// To crop, elements of the first bottom are selected to fit the dimensions
// of the second, reference bottom. The crop is configured by
// - the crop `axis` to pick the dimensions for cropping
// - the crop `offset` to set the shift for all/each dimension
// to align the cropped bottom with the reference bottom.
// All dimensions up to but excluding `axis` are preserved, while
// the dimensions including and trailing `axis` are cropped.
// If only one `offset` is set, then all dimensions are offset by this amount.
// Otherwise, the number of offsets must equal the number of cropped axes to
// shift the crop in each dimension accordingly.
// Note: standard dimensions are N,C,H,W so the default is a spatial crop,
// and `axis` may be negative to index from the end (e.g., -1 for the last
// axis).
optional int32 axis = 1 [ default = 2 ];
repeated uint32 offset = 2;
}
message DataParameter {
enum DB {
LEVELDB = 0;
LMDB = 1;
}
// Specify the data source.
optional string source = 1;
// Specify the batch size.
optional uint32 batch_size = 4;
// The rand_skip variable is for the data layer to skip a few data points
// to avoid all asynchronous sgd clients to start at the same point. The skip
// point would be set as rand_skip * rand(0,1). Note that rand_skip should not
// be larger than the number of keys in the database.
// DEPRECATED. Each solver accesses a different subset of the database.
optional uint32 rand_skip = 7 [ default = 0 ];
optional DB backend = 8 [ default = LEVELDB ];
// DEPRECATED. See TransformationParameter. For data pre-processing, we can do
// simple scaling and subtracting the data mean, if provided. Note that the
// mean subtraction is always carried out before scaling.
optional float scale = 2 [ default = 1 ];
optional string mean_file = 3;
// DEPRECATED. See TransformationParameter. Specify if we would like to
// randomly
// crop an image.
optional uint32 crop_size = 5 [ default = 0 ];
// DEPRECATED. See TransformationParameter. Specify if we want to randomly
// mirror
// data.
optional bool mirror = 6 [ default = false ];
// Force the encoded image to have 3 color channels
optional bool force_encoded_color = 9 [ default = false ];
// Prefetch queue (Increase if data feeding bandwidth varies, within the
// limit of device memory for GPU training)
optional uint32 prefetch = 10 [ default = 4 ];
}
message DropoutParameter {
optional float dropout_ratio = 1 [ default = 0.5 ]; // dropout ratio
}
// DummyDataLayer fills any number of arbitrarily shaped blobs with random
// (or constant) data generated by "Fillers" (see "message FillerParameter").
message DummyDataParameter {
// This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or
// N
// shape fields, and 0, 1 or N data_fillers.
//
// If 0 data_fillers are specified, ConstantFiller with a value of 0 is used.
// If 1 data_filler is specified, it is applied to all top blobs. If N are
// specified, the ith is applied to the ith top blob.
repeated FillerParameter data_filler = 1;
repeated BlobShape shape = 6;
// 4D dimensions -- deprecated. Use "shape" instead.
repeated uint32 num = 2;
repeated uint32 channels = 3;
repeated uint32 height = 4;
repeated uint32 width = 5;
}
message EltwiseParameter {
enum EltwiseOp {
PROD = 0;
SUM = 1;
MAX = 2;
}
optional EltwiseOp operation = 1 [ default = SUM ]; // element-wise operation
repeated float coeff = 2; // blob-wise coefficient for SUM operation
// Whether to use an asymptotically slower (for >2 inputs) but stabler method
// of computing the gradient for the PROD operation. (No effect for SUM op.)
optional bool stable_prod_grad = 3 [ default = true ];
}
// Message that stores parameters used by ELULayer
message ELUParameter {
// Described in:
// Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate
// Deep Network Learning by Exponential Linear Units (ELUs). arXiv
optional float alpha = 1 [ default = 1 ];
}
// Message that stores parameters used by EmbedLayer
message EmbedParameter {
optional uint32 num_output = 1; // The number of outputs for the layer
// The input is given as integers to be interpreted as one-hot
// vector indices with dimension num_input. Hence num_input should be
// 1 greater than the maximum possible input value.
optional uint32 input_dim = 2;
optional bool bias_term = 3 [ default = true ]; // Whether to use a bias term
optional FillerParameter weight_filler = 4; // The filler for the weight
optional FillerParameter bias_filler = 5; // The filler for the bias
}
// Message that stores parameters used by ExpLayer
message ExpParameter {
// ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0.
// Or if base is set to the default (-1), base is set to e,
// so y = exp(shift + scale * x).
optional float base = 1 [ default = -1.0 ];
optional float scale = 2 [ default = 1.0 ];
optional float shift = 3 [ default = 0.0 ];
}
/// Message that stores parameters used by FlattenLayer
message FlattenParameter {
// The first axis to flatten: all preceding axes are retained in the output.
// May be negative to index from the end (e.g., -1 for the last axis).
optional int32 axis = 1 [ default = 1 ];
// The last axis to flatten: all following axes are retained in the output.
// May be negative to index from the end (e.g., the default -1 for the last
// axis).
optional int32 end_axis = 2 [ default = -1 ];
}
// Message that stores parameters used by HDF5DataLayer
message HDF5DataParameter {
// Specify the data source.
optional string source = 1;
// Specify the batch size.
optional uint32 batch_size = 2;
// Specify whether to shuffle the data.
// If shuffle == true, the ordering of the HDF5 files is shuffled,
// and the ordering of data within any given HDF5 file is shuffled,
// but data between different files are not interleaved; all of a file's
// data are output (in a random order) before moving onto another file.
optional bool shuffle = 3 [ default = false ];
}
message HDF5OutputParameter { optional string file_name = 1; }
message HingeLossParameter {
enum Norm {
L1 = 1;
L2 = 2;
}
// Specify the Norm to use L1 or L2
optional Norm norm = 1 [ default = L1 ];
}
message ImageDataParameter {
// Specify the data source.
optional string source = 1;
// Specify the batch size.
optional uint32 batch_size = 4 [ default = 1 ];
// The rand_skip variable is for the data layer to skip a few data points
// to avoid all asynchronous sgd clients to start at the same point. The skip
// point would be set as rand_skip * rand(0,1). Note that rand_skip should not
// be larger than the number of keys in the database.
optional uint32 rand_skip = 7 [ default = 0 ];
// Whether or not ImageLayer should shuffle the list of files at every epoch.
optional bool shuffle = 8 [ default = false ];
// It will also resize images if new_height or new_width are not zero.
optional uint32 new_height = 9 [ default = 0 ];
optional uint32 new_width = 10 [ default = 0 ];
// Specify if the images are color or gray
optional bool is_color = 11 [ default = true ];
// DEPRECATED. See TransformationParameter. For data pre-processing, we can do
// simple scaling and subtracting the data mean, if provided. Note that the
// mean subtraction is always carried out before scaling.
optional float scale = 2 [ default = 1 ];
optional string mean_file = 3;
// DEPRECATED. See TransformationParameter. Specify if we would like to
// randomly
// crop an image.
optional uint32 crop_size = 5 [ default = 0 ];
// DEPRECATED. See TransformationParameter. Specify if we want to randomly
// mirror
// data.
optional bool mirror = 6 [ default = false ];
optional string root_folder = 12 [ default = "" ];
}
message InfogainLossParameter {
// Specify the infogain matrix source.
optional string source = 1;
optional int32 axis = 2 [ default = 1 ]; // axis of prob
}
message InnerProductParameter {
optional uint32 num_output = 1; // The number of outputs for the layer
optional bool bias_term = 2 [ default = true ]; // whether to have bias terms
optional FillerParameter weight_filler = 3; // The filler for the weight
optional FillerParameter bias_filler = 4; // The filler for the bias
// The first axis to be lumped into a single inner product computation;
// all preceding axes are retained in the output.
// May be negative to index from the end (e.g., -1 for the last axis).
optional int32 axis = 5 [ default = 1 ];
// Specify whether to transpose the weight matrix or not.
// If transpose == true, any operations will be performed on the transpose
// of the weight matrix. The weight matrix itself is not going to be
// transposed
// but rather the transfer flag of operations will be toggled accordingly.
optional bool transpose = 6 [ default = false ];
}
message InputParameter {
// This layer produces N >= 1 top blob(s) to be assigned manually.
// Define N shapes to set a shape for each top.
// Define 1 shape to set the same shape for every top.
// Define no shape to defer to reshaping manually.
repeated BlobShape shape = 1;
}
// Message that stores parameters used by LogLayer
message LogParameter {
// LogLayer computes outputs y = log_base(shift + scale * x), for base > 0.
// Or if base is set to the default (-1), base is set to e,
// so y = ln(shift + scale * x) = log_e(shift + scale * x)
optional float base = 1 [ default = -1.0 ];
optional float scale = 2 [ default = 1.0 ];
optional float shift = 3 [ default = 0.0 ];
}
// Message that stores parameters used by LRNLayer
message LRNParameter {
optional uint32 local_size = 1 [ default = 5 ];
optional float alpha = 2 [ default = 1. ];
optional float beta = 3 [ default = 0.75 ];
enum NormRegion {
ACROSS_CHANNELS = 0;
WITHIN_CHANNEL = 1;
}
optional NormRegion norm_region = 4 [ default = ACROSS_CHANNELS ];
optional float k = 5 [ default = 1. ];
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 6 [ default = DEFAULT ];
}
message MemoryDataParameter {
optional uint32 batch_size = 1;
optional uint32 channels = 2;
optional uint32 height = 3;
optional uint32 width = 4;
}
message MVNParameter {
// This parameter can be set to false to normalize mean only
optional bool normalize_variance = 1 [ default = true ];
// This parameter can be set to true to perform DNN-like MVN
optional bool across_channels = 2 [ default = false ];
// Epsilon for not dividing by zero while normalizing variance
optional float eps = 3 [ default = 1e-9 ];
}
message ParameterParameter { optional BlobShape shape = 1; }
message PoolingParameter {
enum PoolMethod {
MAX = 0;
AVE = 1;
STOCHASTIC = 2;
}
optional PoolMethod pool = 1 [ default = MAX ]; // The pooling method
// Pad, kernel size, and stride are all given as a single value for equal
// dimensions in height and width or as Y, X pairs.
optional uint32 pad = 4 [ default = 0 ]; // The padding size (equal in Y, X)
optional uint32 pad_h = 9 [ default = 0 ]; // The padding height
optional uint32 pad_w = 10 [ default = 0 ]; // The padding width
optional uint32 kernel_size = 2; // The kernel size (square)
optional uint32 kernel_h = 5; // The kernel height
optional uint32 kernel_w = 6; // The kernel width
optional uint32 stride = 3 [ default = 1 ]; // The stride (equal in Y, X)
optional uint32 stride_h = 7; // The stride height
optional uint32 stride_w = 8; // The stride width
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 11 [ default = DEFAULT ];
// If global_pooling then it will pool over the size of the bottom by doing
// kernel_h = bottom->height and kernel_w = bottom->width
optional bool global_pooling = 12 [ default = false ];
}
message PowerParameter {
// PowerLayer computes outputs y = (shift + scale * x) ^ power.
optional float power = 1 [ default = 1.0 ];
optional float scale = 2 [ default = 1.0 ];
optional float shift = 3 [ default = 0.0 ];
}
message PythonParameter {
optional string module = 1;
optional string layer = 2;
// This value is set to the attribute `param_str` of the `PythonLayer` object
// in Python before calling the `setup()` method. This could be a number,
// string, dictionary in Python dict format, JSON, etc. You may parse this
// string in `setup` method and use it in `forward` and `backward`.
optional string param_str = 3 [ default = ''];
// DEPRECATED
optional bool share_in_parallel = 4 [ default = false ];
}
// Message that stores parameters used by RecurrentLayer
message RecurrentParameter {
// The dimension of the output (and usually hidden state) representation --
// must be explicitly set to non-zero.
optional uint32 num_output = 1 [ default = 0 ];
optional FillerParameter weight_filler = 2; // The filler for the weight
optional FillerParameter bias_filler = 3; // The filler for the bias
// Whether to enable displaying debug_info in the unrolled recurrent net.
optional bool debug_info = 4 [ default = false ];
// Whether to add as additional inputs (bottoms) the initial hidden state
// blobs, and add as additional outputs (tops) the final timestep hidden state
// blobs. The number of additional bottom/top blobs required depends on the
// recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs.
optional bool expose_hidden = 5 [ default = false ];
}
// Message that stores parameters used by ReductionLayer
message ReductionParameter {
enum ReductionOp {
SUM = 1;
ASUM = 2;
SUMSQ = 3;
MEAN = 4;
}
optional ReductionOp operation = 1 [ default = SUM ]; // reduction operation
// The first axis to reduce to a scalar -- may be negative to index from the
// end (e.g., -1 for the last axis).
// (Currently, only reduction along ALL "tail" axes is supported; reduction
// of axis M through N, where N < num_axes - 1, is unsupported.)
// Suppose we have an n-axis bottom Blob with shape:
// (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)).
// If axis == m, the output Blob will have shape
// (d0, d1, d2, ..., d(m-1)),
// and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1))
// times, each including (dm * d(m+1) * ... * d(n-1)) individual data.
// If axis == 0 (the default), the output Blob always has the empty shape
// (count 1), performing reduction across the entire input --
// often useful for creating new loss functions.
optional int32 axis = 2 [ default = 0 ];
optional float coeff = 3 [ default = 1.0 ]; // coefficient for output
}
// Message that stores parameters used by ReLULayer
message ReLUParameter {
// Allow non-zero slope for negative inputs to speed up optimization
// Described in:
// Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities
// improve neural network acoustic models. In ICML Workshop on Deep Learning
// for Audio, Speech, and Language Processing.
optional float negative_slope = 1 [ default = 0 ];
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 2 [ default = DEFAULT ];
}
message ReshapeParameter {
// Specify the output dimensions. If some of the dimensions are set to 0,
// the corresponding dimension from the bottom layer is used (unchanged).
// Exactly one dimension may be set to -1, in which case its value is
// inferred from the count of the bottom blob and the remaining dimensions.
// For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8:
//
// layer {
// type: "Reshape" bottom: "input" top: "output"
// reshape_param { ... }
// }
//
// If "input" is 2D with shape 2 x 8, then the following reshape_param
// specifications are all equivalent, producing a 3D blob "output" with shape
// 2 x 2 x 4:
//
// reshape_param { shape { dim: 2 dim: 2 dim: 4 } }
// reshape_param { shape { dim: 0 dim: 2 dim: 4 } }
// reshape_param { shape { dim: 0 dim: 2 dim: -1 } }
// reshape_param { shape { dim: 0 dim:-1 dim: 4 } }
//
optional BlobShape shape = 1;
// axis and num_axes control the portion of the bottom blob's shape that are
// replaced by (included in) the reshape. By default (axis == 0 and
// num_axes == -1), the entire bottom blob shape is included in the reshape,
// and hence the shape field must specify the entire output shape.
//
// axis may be non-zero to retain some portion of the beginning of the input
// shape (and may be negative to index from the end; e.g., -1 to begin the
// reshape after the last axis, including nothing in the reshape,
// -2 to include only the last axis, etc.).
//
// For example, suppose "input" is a 2D blob with shape 2 x 8.
// Then the following ReshapeLayer specifications are all equivalent,
// producing a blob "output" with shape 2 x 2 x 4:
//
// reshape_param { shape { dim: 2 dim: 2 dim: 4 } }
// reshape_param { shape { dim: 2 dim: 4 } axis: 1 }
// reshape_param { shape { dim: 2 dim: 4 } axis: -3 }
//
// num_axes specifies the extent of the reshape.
// If num_axes >= 0 (and axis >= 0), the reshape will be performed only on
// input axes in the range [axis, axis+num_axes].
// num_axes may also be -1, the default, to include all remaining axes
// (starting from axis).
//
// For example, suppose "input" is a 2D blob with shape 2 x 8.
// Then the following ReshapeLayer specifications are equivalent,
// producing a blob "output" with shape 1 x 2 x 8.
//
// reshape_param { shape { dim: 1 dim: 2 dim: 8 } }
// reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 }
// reshape_param { shape { dim: 1 } num_axes: 0 }
//
// On the other hand, these would produce output blob shape 2 x 1 x 8:
//
// reshape_param { shape { dim: 2 dim: 1 dim: 8 } }
// reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 }
//
optional int32 axis = 2 [ default = 0 ];
optional int32 num_axes = 3 [ default = -1 ];
}
message ScaleParameter {
// The first axis of bottom[0] (the first input Blob) along which to apply
// bottom[1] (the second input Blob). May be negative to index from the end
// (e.g., -1 for the last axis).
//
// For example, if bottom[0] is 4D with shape 100x3x40x60, the output
// top[0] will have the same shape, and bottom[1] may have any of the
// following shapes (for the given value of axis):
// (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60
// (axis == 1 == -3) 3; 3x40; 3x40x60
// (axis == 2 == -2) 40; 40x60
// (axis == 3 == -1) 60
// Furthermore, bottom[1] may have the empty shape (regardless of the value of
// "axis") -- a scalar multiplier.
optional int32 axis = 1 [ default = 1 ];
// (num_axes is ignored unless just one bottom is given and the scale is
// a learned parameter of the layer. Otherwise, num_axes is determined by the
// number of axes by the second bottom.)
// The number of axes of the input (bottom[0]) covered by the scale
// parameter, or -1 to cover all axes of bottom[0] starting from `axis`.
// Set num_axes := 0, to multiply with a zero-axis Blob: a scalar.
optional int32 num_axes = 2 [ default = 1 ];
// (filler is ignored unless just one bottom is given and the scale is
// a learned parameter of the layer.)
// The initialization for the learned scale parameter.
// Default is the unit (1) initialization, resulting in the ScaleLayer
// initially performing the identity operation.
optional FillerParameter filler = 3;
// Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but
// may be more efficient). Initialized with bias_filler (defaults to 0).
optional bool bias_term = 4 [ default = false ];
optional FillerParameter bias_filler = 5;
}
message SigmoidParameter {
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 1 [ default = DEFAULT ];
}
message SliceParameter {
// The axis along which to slice -- may be negative to index from the end
// (e.g., -1 for the last axis).
// By default, SliceLayer concatenates blobs along the "channels" axis (1).
optional int32 axis = 3 [ default = 1 ];
repeated uint32 slice_point = 2;
// DEPRECATED: alias for "axis" -- does not support negative indexing.
optional uint32 slice_dim = 1 [ default = 1 ];
}
// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer
message SoftmaxParameter {
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 1 [ default = DEFAULT ];
// The axis along which to perform the softmax -- may be negative to index
// from the end (e.g., -1 for the last axis).
// Any other axes will be evaluated as independent softmaxes.
optional int32 axis = 2 [ default = 1 ];
}
message TanHParameter {
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 1 [ default = DEFAULT ];
}
// Message that stores parameters used by TileLayer
message TileParameter {
// The index of the axis to tile.
optional int32 axis = 1 [ default = 1 ];
// The number of copies (tiles) of the blob to output.
optional int32 tiles = 2;
}
// Message that stores parameters used by ThresholdLayer
message ThresholdParameter {
optional float threshold = 1 [ default = 0 ]; // Strictly positive values
}
message WindowDataParameter {
// Specify the data source.
optional string source = 1;
// For data pre-processing, we can do simple scaling and subtracting the
// data mean, if provided. Note that the mean subtraction is always carried
// out before scaling.
optional float scale = 2 [ default = 1 ];
optional string mean_file = 3;
// Specify the batch size.
optional uint32 batch_size = 4;
// Specify if we would like to randomly crop an image.
optional uint32 crop_size = 5 [ default = 0 ];
// Specify if we want to randomly mirror data.
optional bool mirror = 6 [ default = false ];
// Foreground (object) overlap threshold
optional float fg_threshold = 7 [ default = 0.5 ];
// Background (non-object) overlap threshold
optional float bg_threshold = 8 [ default = 0.5 ];
// Fraction of batch that should be foreground objects
optional float fg_fraction = 9 [ default = 0.25 ];
// Amount of contextual padding to add around a window
// (used only by the window_data_layer)
optional uint32 context_pad = 10 [ default = 0 ];
// Mode for cropping out a detection window
// warp: cropped window is warped to a fixed size and aspect ratio
// square: the tightest square around the window is cropped
optional string crop_mode = 11 [ default = "warp" ];
// cache_images: will load all images in memory for faster access
optional bool cache_images = 12 [ default = false ];
// append root_folder to locate images
optional string root_folder = 13 [ default = "" ];
}
message SPPParameter {
enum PoolMethod {
MAX = 0;
AVE = 1;
STOCHASTIC = 2;
}
optional uint32 pyramid_height = 1;
optional PoolMethod pool = 2 [ default = MAX ]; // The pooling method
enum Engine {
DEFAULT = 0;
CAFFE = 1;
CUDNN = 2;
}
optional Engine engine = 6 [ default = DEFAULT ];
}
// DEPRECATED: use LayerParameter.
message V1LayerParameter {
repeated string bottom = 2;
repeated string top = 3;
optional string name = 4;
repeated NetStateRule include = 32;
repeated NetStateRule exclude = 33;
enum LayerType {
NONE = 0;
ABSVAL = 35;
ACCURACY = 1;
ARGMAX = 30;
BNLL = 2;
CONCAT = 3;
CONTRASTIVE_LOSS = 37;
CONVOLUTION = 4;
DATA = 5;
DECONVOLUTION = 39;
DROPOUT = 6;
DUMMY_DATA = 32;
EUCLIDEAN_LOSS = 7;
ELTWISE = 25;
EXP = 38;
FLATTEN = 8;
HDF5_DATA = 9;
HDF5_OUTPUT = 10;
HINGE_LOSS = 28;
IM2COL = 11;
IMAGE_DATA = 12;
INFOGAIN_LOSS = 13;
INNER_PRODUCT = 14;
LRN = 15;
MEMORY_DATA = 29;
MULTINOMIAL_LOGISTIC_LOSS = 16;
MVN = 34;
POOLING = 17;
POWER = 26;
RELU = 18;
SIGMOID = 19;
SIGMOID_CROSS_ENTROPY_LOSS = 27;
SILENCE = 36;
SOFTMAX = 20;
SOFTMAX_LOSS = 21;
SPLIT = 22;
SLICE = 33;
TANH = 23;
WINDOW_DATA = 24;
THRESHOLD = 31;
}
optional LayerType type = 5;
repeated BlobProto blobs = 6;
repeated string param = 1001;
repeated DimCheckMode blob_share_mode = 1002;
enum DimCheckMode {
STRICT = 0;
PERMISSIVE = 1;
}
repeated float blobs_lr = 7;
repeated float weight_decay = 8;
repeated float loss_weight = 35;
optional AccuracyParameter accuracy_param = 27;
optional ArgMaxParameter argmax_param = 23;
optional ConcatParameter concat_param = 9;
optional ContrastiveLossParameter contrastive_loss_param = 40;
optional ConvolutionParameter convolution_param = 10;
optional DataParameter data_param = 11;
optional DropoutParameter dropout_param = 12;
optional DummyDataParameter dummy_data_param = 26;
optional EltwiseParameter eltwise_param = 24;
optional ExpParameter exp_param = 41;
optional HDF5DataParameter hdf5_data_param = 13;
optional HDF5OutputParameter hdf5_output_param = 14;
optional HingeLossParameter hinge_loss_param = 29;
optional ImageDataParameter image_data_param = 15;
optional InfogainLossParameter infogain_loss_param = 16;
optional InnerProductParameter inner_product_param = 17;
optional LRNParameter lrn_param = 18;
optional MemoryDataParameter memory_data_param = 22;
optional MVNParameter mvn_param = 34;
optional PoolingParameter pooling_param = 19;
optional PowerParameter power_param = 21;
optional ReLUParameter relu_param = 30;
optional SigmoidParameter sigmoid_param = 38;
optional SoftmaxParameter softmax_param = 39;
optional SliceParameter slice_param = 31;
optional TanHParameter tanh_param = 37;
optional ThresholdParameter threshold_param = 25;
optional WindowDataParameter window_data_param = 20;
optional TransformationParameter transform_param = 36;
optional LossParameter loss_param = 42;
optional V0LayerParameter layer = 1;
}
// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters
// in Caffe. We keep this message type around for legacy support.
message V0LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the string to specify the layer type
// Parameters to specify layers with inner products.
optional uint32 num_output = 3; // The number of outputs for the layer
optional bool biasterm = 4 [ default = true ]; // whether to have bias terms
optional FillerParameter weight_filler = 5; // The filler for the weight
optional FillerParameter bias_filler = 6; // The filler for the bias
optional uint32 pad = 7 [ default = 0 ]; // The padding size
optional uint32 kernelsize = 8; // The kernel size
optional uint32 group = 9 [ default = 1 ]; // The group size for group conv
optional uint32 stride = 10 [ default = 1 ]; // The stride
enum PoolMethod {
MAX = 0;
AVE = 1;
STOCHASTIC = 2;
}
optional PoolMethod pool = 11 [ default = MAX ]; // The pooling method
optional float dropout_ratio = 12 [ default = 0.5 ]; // dropout ratio
optional uint32 local_size = 13 [ default = 5 ]; // for local response norm
optional float alpha = 14 [ default = 1. ]; // for local response norm
optional float beta = 15 [ default = 0.75 ]; // for local response norm
optional float k = 22 [ default = 1. ];
// For data layers, specify the data source
optional string source = 16;
// For data pre-processing, we can do simple scaling and subtracting the
// data mean, if provided. Note that the mean subtraction is always carried
// out before scaling.
optional float scale = 17 [ default = 1 ];
optional string meanfile = 18;
// For data layers, specify the batch size.
optional uint32 batchsize = 19;
// For data layers, specify if we would like to randomly crop an image.
optional uint32 cropsize = 20 [ default = 0 ];
// For data layers, specify if we want to randomly mirror data.
optional bool mirror = 21 [ default = false ];
// The blobs containing the numeric parameters of the layer
repeated BlobProto blobs = 50;
// The ratio that is multiplied on the global learning rate. If you want to
// set the learning ratio for one blob, you need to set it for all blobs.
repeated float blobs_lr = 51;
// The weight decay that is multiplied on the global weight decay.
repeated float weight_decay = 52;
// The rand_skip variable is for the data layer to skip a few data points
// to avoid all asynchronous sgd clients to start at the same point. The skip
// point would be set as rand_skip * rand(0,1). Note that rand_skip should not
// be larger than the number of keys in the database.
optional uint32 rand_skip = 53 [ default = 0 ];
// Fields related to detection (det_*)
// foreground (object) overlap threshold
optional float det_fg_threshold = 54 [ default = 0.5 ];
// background (non-object) overlap threshold
optional float det_bg_threshold = 55 [ default = 0.5 ];
// Fraction of batch that should be foreground objects
optional float det_fg_fraction = 56 [ default = 0.25 ];
// optional bool OBSOLETE_can_clobber = 57 [default = true];
// Amount of contextual padding to add around a window
// (used only by the window_data_layer)
optional uint32 det_context_pad = 58 [ default = 0 ];
// Mode for cropping out a detection window
// warp: cropped window is warped to a fixed size and aspect ratio
// square: the tightest square around the window is cropped
optional string det_crop_mode = 59 [ default = "warp" ];
// For ReshapeLayer, one needs to specify the new dimensions.
optional int32 new_num = 60 [ default = 0 ];
optional int32 new_channels = 61 [ default = 0 ];
optional int32 new_height = 62 [ default = 0 ];
optional int32 new_width = 63 [ default = 0 ];
// Whether or not ImageLayer should shuffle the list of files at every epoch.
// It will also resize images if new_height or new_width are not zero.
optional bool shuffle_images = 64 [ default = false ];
// For ConcatLayer, one needs to specify the dimension for concatenation, and
// the other dimensions must be the same for all the bottom blobs.
// By default it will concatenate blobs along the channels dimension.
optional uint32 concat_dim = 65 [ default = 1 ];
optional HDF5OutputParameter hdf5_output_param = 1001;
}
message PReLUParameter {
// Parametric ReLU described in K. He et al, Delving Deep into Rectifiers:
// Surpassing Human-Level Performance on ImageNet Classification, 2015.
// Initial value of a_i. Default is a_i=0.25 for all i.
optional FillerParameter filler = 1;
// Whether or not slope parameters are shared across channels.
optional bool channel_shared = 2 [ default = false ];
}
#!/bin/bash
#function:
# script used to generate caffepb.py from caffe.proto using protoc
#
PROTOC=`which protoc`
if [[ -z $PROTOC ]];then
echo "not found protoc, you should first install it following this[https://github.com/google/protobuf/releases]"
exit 1
fi
WORK_ROOT=$(dirname `readlink -f "$BASH_SOURCE[0]"`)
PY_NAME="$WORK_ROOT/caffe_pb2.py"
$PROTOC --proto_path=$WORK_ROOT --python_out=$WORK_ROOT $WORK_ROOT/caffe.proto
ret=$?
if [ -e "$PY_NAME" ];then
echo "succeed to generate [$PY_NAME]"
exit 0
else
echo "failed to generate [$PY_NAME]"
fi
exit $ret
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册