diff --git a/fluid/image_classification/caffe2fluid/README.md b/fluid/image_classification/caffe2fluid/README.md
index 279b4c6e57a785736a1c75928de8d45f4e4e956e..5f565afe0c33db291092faeac632da3d51f95613 100644
--- a/fluid/image_classification/caffe2fluid/README.md
+++ b/fluid/image_classification/caffe2fluid/README.md
@@ -2,7 +2,8 @@ This tool is used to convert a Caffe model to Fluid model
 
 ### Howto
-1, Prepare caffepb.py in ./proto, two options provided
+1, Prepare caffepb.py in ./proto if your Python has no 'pycaffe' module; two options are provided here:
+
    1) generate it from caffe.proto using protoc
        bash ./proto/compile.sh
 
@@ -12,14 +13,24 @@ This tool is used to convert a Caffe model to Fluid model
 2, Convert the caffe model using 'convert.py' which will generate a python script and a weight(in .npy) file
 
 3, Use the converted model to predict
-   see more detail info in 'tests/lenet/README.md'
+
+   see more details in 'examples/xxx'
 
-### Supported models
+### Tested models
 - Lenet on mnist dataset
 
 - ResNets:(ResNet-50, ResNet-101, ResNet-152)
-  model addrs:(https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)
+  model addr: `https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777`_
+
+- GoogleNet:
+  model addr: `https://gist.github.com/jimmie33/7ea9f8ac0da259866b854460f4526034`_
+
+- VGG:
+  model addr: `https://gist.github.com/ksimonyan/211839e770f7b538e2d8`_
+
+- AlexNet:
+  model addr: `https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet`_
 
 ### Notes
 Some of this code come from here: https://github.com/ethereon/caffe-tensorflow
diff --git a/fluid/image_classification/caffe2fluid/convert.py b/fluid/image_classification/caffe2fluid/convert.py
index 68a9e4f7e490a69c1b582d6fc14b2015bfdf9536..44420c837b0bc4401f7e0f54e3a184af57e53d9a 100755
--- a/fluid/image_classification/caffe2fluid/convert.py
+++ b/fluid/image_classification/caffe2fluid/convert.py
@@ -4,8 +4,8 @@ import os
 import sys
 import numpy as np
 import argparse
-from kaffe import KaffeError, print_stderr
 
+from kaffe import KaffeError, print_stderr
 from kaffe.paddle import Transformer
 
@@ -47,6 +47,7 @@ def convert(def_path, caffemodel_path, data_output_path, code_output_path,
     except KaffeError as err:
         fatal_error('Error encountered: {}'.format(err))
 
+    return 0
 
 def main():
     """ main
@@ -69,4 +70,5 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
+    ret = main()
+    sys.exit(ret)
diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/README.md b/fluid/image_classification/caffe2fluid/examples/imagenet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b82050859239be8804ddec8e2054edc38c4ac052
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/examples/imagenet/README.md
@@ -0,0 +1,10 @@
+A demo showing how to convert Caffe models trained on 'imagenet' using caffe2fluid
+
+---
+
+# How to use
+
+1. Prepare the python environment
+2. Download a caffe model to "models.caffe/xxx", which contains "xxx.caffemodel" and "xxx.prototxt"
+3. Run the tool
+   eg: bash ./run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50
diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/data/65.jpeg b/fluid/image_classification/caffe2fluid/examples/imagenet/data/65.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..fd3a93f59385d6ff632483646e6caee300b56d09
Binary files /dev/null and b/fluid/image_classification/caffe2fluid/examples/imagenet/data/65.jpeg differ
diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py b/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5465e3da8475feb46102551fe1ca40f28cd13b7e
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py
@@ -0,0 +1,139 @@
+#!/bin/env python
+
+#function:
+#   a demo to show how to use the converted model generated by caffe2fluid
+#
+#notes:
+#   only supports imagenet data
+
+import os
+import sys
+import inspect
+import numpy as np
+import paddle.v2 as paddle
+import paddle.v2.fluid as fluid
+
+def load_data(imgfile, shape):
+    h, w = shape[1:]
+    from PIL import Image
+    im = Image.open(imgfile)
+
+    # The loaded image is stored in H(height),
+    # W(width), C(channel) order. PaddlePaddle requires
+    # the CHW order, so transpose them.
+    im = im.resize((w, h), Image.ANTIALIAS)
+    im = np.array(im).astype(np.float32)
+    im = im.transpose((2, 0, 1))  # CHW
+    im = im[(2, 1, 0), :, :]  # BGR
+
+    # The mean to be subtracted from each image.
+    # By default, the per-channel ImageNet mean.
+    mean = np.array([104., 117., 124.], dtype=np.float32)
+    mean = mean.reshape([3, 1, 1])
+    im = im - mean
+    return im.reshape([1] + shape)
+
+
+def build_model(net_file, net_name):
+    print('build model with net_file[%s] and net_name[%s]' % (net_file, net_name))
+
+    net_path = os.path.dirname(net_file)
+    module_name = os.path.splitext(os.path.basename(net_file))[0]
+    if net_path not in sys.path:
+        sys.path.insert(0, net_path)
+
+    try:
+        m = __import__(module_name, fromlist=[net_name])
+        MyNet = getattr(m, net_name)
+    except Exception as e:
+        print('failed to load module[%s]' % (module_name))
+        print(e)
+        return None
+
+    input_name = 'data'
+    input_shape = MyNet.input_shapes()[input_name]
+    images = fluid.layers.data(name='image', shape=input_shape, dtype='float32')
+    #label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+    net = MyNet({input_name: images})
+    input_shape = MyNet.input_shapes()[input_name]
+    return net, input_shape
+
+
+def dump_results(results, names, root):
+    if os.path.exists(root) is False:
+        os.mkdir(root)
+
+    for i in range(len(names)):
+        n = names[i]
+        res = results[i]
+        filename = os.path.join(root, n)
+        np.save(filename + '.npy', res)
+
+
+def infer(net_file, net_name, model_file, imgfile, debug=False):
+    """ do inference using a model which consists of 'xxx.py' and 'xxx.npy'
+    """
+    #1, build model
+    net, input_shape = build_model(net_file, net_name)
+    prediction = net.get_output()
+
+    #2, load weights for this model
+    place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    startup_program = fluid.default_startup_program()
+    exe.run(startup_program)
+
+    if model_file.find('.npy') > 0:
+        net.load(data_path=model_file, exe=exe, place=place)
+    else:
+        net.load(data_path=model_file, exe=exe)
+
+    #3, test this model
+    test_program = fluid.default_main_program().clone()
+
+    fetch_list_var = []
+    fetch_list_name = []
+    if debug is False:
+        fetch_list_var.append(prediction)
+    else:
+        for k, v in net.layers.items():
+            fetch_list_var.append(v)
+            fetch_list_name.append(k)
+
+    np_images = load_data(imgfile, input_shape)
+    results = exe.run(program=test_program,
+                      feed={'image': np_images},
+                      fetch_list=fetch_list_var)
+
+    if debug is True:
+        dump_path = 'results.layers'
+        dump_results(results, fetch_list_name, dump_path)
+        print('all results dumped to [%s]' % (dump_path))
+    else:
+        result = results[0]
+        print('predicted class:', np.argmax(result))
+
+
+if __name__ == "__main__":
+    """ maybe more convenient to use 'run.sh' to call this tool
+    """
+    net_file = 'models/resnet50/resnet50.py'
+    weight_file = 'models/resnet50/resnet50.npy'
+    imgfile = 'data/65.jpeg'
+    net_name = 'ResNet50'
+
+    argc = len(sys.argv)
+    if argc == 5:
+        net_file = sys.argv[1]
+        weight_file = sys.argv[2]
+        imgfile = sys.argv[3]
+        net_name = sys.argv[4]
+    elif argc > 1:
+        print('usage:')
+        print('\tpython %s [net_file] [weight_file] [imgfile] [net_name]' % (sys.argv[0]))
+        print('\teg:python %s %s %s %s %s' % (sys.argv[0],
+            net_file, weight_file, imgfile, net_name))
+        sys.exit(1)
+
+    infer(net_file, net_name, weight_file, imgfile)
diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/run.sh b/fluid/image_classification/caffe2fluid/examples/imagenet/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7a1a5ebd7c0a5090c00a0c8ca6b0e11b110967dc
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/examples/imagenet/run.sh
@@ -0,0 +1,72 @@
+#!/bin/bash
+
+#function:
+#   a tool used to:
+#       1, convert a caffe model
+#       2, do inference using this model
+#
+#usage:
+#   bash run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50
+#
+
+#set -x
+if [[ $# -lt 3 ]];then
+    echo "usage:"
+    echo "    bash $0 [model_name] [cf_model_path] [pd_model_path] [only_convert]"
+    echo "    eg: bash $0 resnet50 ./models.caffe/resnet50 ./models/resnet50"
+    exit 1
+else
+    model_name=$1
+    cf_model_path=$2
+    pd_model_path=$3
+    only_convert=$4
+fi
+
+proto_file=$cf_model_path/${model_name}.prototxt
+caffemodel_file=$cf_model_path/${model_name}.caffemodel
+weight_file=$pd_model_path/${model_name}.npy
+net_file=$pd_model_path/${model_name}.py
+
+if [[ ! -e $proto_file ]];then
+    echo "not found prototxt[$proto_file]"
+    exit 1
+fi
+
+if [[ ! -e $caffemodel_file ]];then
+    echo "not found caffemodel[$caffemodel_file]"
+    exit 1
+fi
+
+if [[ ! -e $pd_model_path ]];then
+    mkdir $pd_model_path
+fi
+
+PYTHON=`which cfpython`
+if [[ -z $PYTHON ]];then
+    PYTHON=`which python`
+fi
+$PYTHON ../../convert.py \
+        $proto_file \
+        --caffemodel $caffemodel_file \
+        --data-output-path $weight_file \
+        --code-output-path $net_file
+
+ret=$?
+if [[ $ret -ne 0 ]];then
+    echo "failed to convert caffe model[$cf_model_path]"
+    exit $ret
+else
+    echo "succeeded in converting caffe model[$cf_model_path] to fluid model[$pd_model_path]"
+fi
+
+if [[ -z $only_convert ]];then
+    PYTHON=`which pdpython`
+    if [[ -z $PYTHON ]];then
+        PYTHON=`which python`
+    fi
+    imgfile="data/65.jpeg"
+    net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/\"([^\"]+)\"/){ print $1."\n";}'`
+    $PYTHON ./infer.py $net_file $weight_file $imgfile $net_name
+    ret=$?
+fi
+exit $ret
diff --git a/fluid/image_classification/caffe2fluid/examples/mnist/README.md b/fluid/image_classification/caffe2fluid/examples/mnist/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cd427d632737c8988403f987d86c159500022198
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/examples/mnist/README.md
@@ -0,0 +1,10 @@
+A demo showing how to convert a Caffe model trained on 'mnist' using caffe2fluid
+
+---
+
+# How to use
+
+1. Prepare the python environment
+2. Download the caffe model to "models.caffe/lenet", which contains "lenet.caffemodel" and "lenet.prototxt"
+3. Run the tool
+   eg: bash ./run.sh lenet ./models.caffe/lenet ./models/lenet
diff --git a/fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py b/fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b67342832d4b6569f769d0dd9f40ddc2ea3699
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py
@@ -0,0 +1,86 @@
+#!/bin/env python
+
+#function:
+#   a demo showing how to use a model converted by caffe2fluid
+#
+
+import sys
+import os
+import numpy as np
+import paddle.v2 as paddle
+import paddle.v2.fluid as fluid
+
+def test_model(exe, test_program, fetch_list, test_reader, feeder):
+    acc_set = []
+
+    for data in test_reader():
+        acc_np, pred = exe.run(program=test_program,
+                               feed=feeder.feed(data),
+                               fetch_list=fetch_list)
+        acc_set.append(float(acc_np))
+
+    acc_val = np.array(acc_set).mean()
+    return float(acc_val)
+
+
+def evaluate(net_file, model_file):
+    """ main
+    """
+    #1, load the converted model module
+    net_path = os.path.dirname(net_file)
+    if net_path not in sys.path:
+        sys.path.insert(0, net_path)
+
+    from lenet import LeNet as MyNet
+
+    with_gpu = False
+    paddle.init(use_gpu=with_gpu)
+
+    #2, define network topology
+    images = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+    net = MyNet({'data': images})
+    prediction = net.layers['prob']
+    acc = fluid.layers.accuracy(input=prediction, label=label)
+
+    place = fluid.CUDAPlace(0) if with_gpu is True else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    exe.run(fluid.default_startup_program())
+
+    #3, load weights
+    if model_file.find('.npy') > 0:
+        net.load(data_path=model_file, exe=exe, place=place)
+    else:
+        net.load(data_path=model_file, exe=exe)
+
+    #4, test this model
+    test_program = fluid.default_main_program().clone()
+    test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
+
+    feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
+    fetch_list = [acc, prediction]
+
+    print('evaluating the model on the mnist test set')
+    acc_val = test_model(exe, test_program, \
+                fetch_list, test_reader, feeder)
+
+    print('test accuracy is [%.4f], expected value[0.919]' % (acc_val))
+
+
+if __name__ == "__main__":
+    net_file = 'models/lenet/lenet.py'
+    weight_file = 'models/lenet/lenet.npy'
+
+    argc = len(sys.argv)
+    if argc == 3:
+        net_file = sys.argv[1]
+        weight_file = sys.argv[2]
+    elif argc > 1:
+        print('usage:')
+        print('\tpython %s [net_file] [weight_file]' % (sys.argv[0]))
+        print('\teg:python %s %s %s' % (sys.argv[0],
+            net_file, weight_file))
+        sys.exit(1)
+
+    evaluate(net_file, weight_file)
diff --git a/fluid/image_classification/caffe2fluid/examples/mnist/run.sh b/fluid/image_classification/caffe2fluid/examples/mnist/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..eee83ef7cefd594c62fd95db525f081a27c6ea38
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/examples/mnist/run.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+
+#function:
+#   a tool used to:
+#       1, convert a caffe model
+#       2, do inference using this model
+#
+#usage:
+#   bash run.sh lenet ./models.caffe/lenet ./models/lenet
+#
+
+#set -x
+if [[ $# -lt 3 ]];then
+    echo "usage:"
+    echo "    bash $0 [model_name] [cf_model_path] [pd_model_path] [only_convert]"
+    echo "    eg: bash $0 lenet ./models.caffe/lenet ./models/lenet"
+    exit 1
+else
+    model_name=$1
+    cf_model_path=$2
+    pd_model_path=$3
+    only_convert=$4
+fi
+
+proto_file=$cf_model_path/${model_name}.prototxt
+caffemodel_file=$cf_model_path/${model_name}.caffemodel
+weight_file=$pd_model_path/${model_name}.npy
+net_file=$pd_model_path/${model_name}.py
+
+if [[ ! -e $proto_file ]];then
+    echo "not found prototxt[$proto_file]"
+    exit 1
+fi
+
+if [[ ! -e $caffemodel_file ]];then
+    echo "not found caffemodel[$caffemodel_file]"
+    exit 1
+fi
+
+if [[ ! -e $pd_model_path ]];then
+    mkdir $pd_model_path
+fi
+
+PYTHON=`which cfpython`
+if [[ -z $PYTHON ]];then
+    PYTHON=`which python`
+fi
+$PYTHON ../../convert.py \
+        $proto_file \
+        --caffemodel $caffemodel_file \
+        --data-output-path $weight_file \
+        --code-output-path $net_file
+
+ret=$?
+if [[ $ret -ne 0 ]];then
+    echo "failed to convert caffe model[$cf_model_path]"
+    exit $ret
+else
+    echo "succeeded in converting caffe model[$cf_model_path] to fluid model[$pd_model_path]"
+fi
+
+if [[ -z $only_convert ]];then
+    PYTHON=`which pdpython`
+    if [[ -z $PYTHON ]];then
+        PYTHON=`which python`
+    fi
+    net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/\"([^\"]+)\"/){ print $1."\n";}'`
+    if [[ $net_name != "LeNet" ]];then
+        echo "only support LeNet"
+        exit 1
+    fi
+    $PYTHON ./evaluate.py $net_file $weight_file
+    ret=$?
+fi
+exit $ret
diff --git a/fluid/image_classification/caffe2fluid/kaffe/caffe/resolver.py b/fluid/image_classification/caffe2fluid/kaffe/caffe/resolver.py
index 5fbd48d3ade5ab4b812210acf82be625871740cb..6ad7767ed8a88f1c0258ad36cc35221c33b641e5 100644
--- a/fluid/image_classification/caffe2fluid/kaffe/caffe/resolver.py
+++ b/fluid/image_classification/caffe2fluid/kaffe/caffe/resolver.py
@@ -54,7 +54,6 @@ def show_fallback_warning():
     WARNING: PyCaffe not found!
     Falling back to a pure protocol buffer implementation.
     * Conversions will be drastically slower.
-    * This backend is UNTESTED!
 ------------------------------------------------------------
 
 '''
diff --git a/fluid/image_classification/caffe2fluid/kaffe/graph.py b/fluid/image_classification/caffe2fluid/kaffe/graph.py
index cb751dffa1ca9cc19214bed12681312942046df6..5387f441852b8a318a41898ee0b62b4903ccdabb 100644
--- a/fluid/image_classification/caffe2fluid/kaffe/graph.py
+++ b/fluid/image_classification/caffe2fluid/kaffe/graph.py
@@ -175,6 +175,7 @@ class GraphBuilder(object):
             kind = NodeKind.map_raw_kind(layer.type)
             if kind is None:
                 raise KaffeError('Unknown layer type encountered: %s' % layer.type)
+
             # We want to use the layer's top names (the "output" names), rather than the
             # name attribute, which is more of readability thing than a functional one.
             # Other layers will refer to a node by its "top name".
@@ -235,6 +236,7 @@ class GraphBuilder(object):
                 node.add_parent(parent_node)
             if len(layer.top) > 1:
                 raise KaffeError('Multiple top nodes are not supported.')
+
             for output_name in layer.top:
                 if output_name == layer.name:
                     # Output is named the same as the node. No further action required.
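
Editor's note: infer.py and evaluate.py in the two examples above share the same usage pattern for a converted model. The sketch below condenses it, assuming the default ResNet50 artifacts produced by the imagenet run.sh ('models/resnet50/resnet50.py' and 'resnet50.npy') and the paddle.v2.fluid API used throughout this patch; the module, class, and file names are only the example defaults, not fixed requirements.

```python
import sys
import numpy as np
import paddle.v2.fluid as fluid

# Assumed artifacts: run.sh above wrote models/resnet50/resnet50.py and resnet50.npy.
sys.path.insert(0, 'models/resnet50')
from resnet50 import ResNet50

# Build the fluid program from the generated module.
input_shape = ResNet50.input_shapes()['data']      # e.g. [3, 224, 224]
image = fluid.layers.data(name='image', shape=input_shape, dtype='float32')
net = ResNet50({'data': image})
prediction = net.get_output()

# Load the converted weights.
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
net.load(data_path='models/resnet50/resnet50.npy', exe=exe, place=place)

# A random tensor stands in for a real, preprocessed image (see load_data in infer.py).
fake_image = np.random.rand(1, *input_shape).astype('float32')
result = exe.run(fluid.default_main_program(),
                 feed={'image': fake_image},
                 fetch_list=[prediction])
print('predicted class:', np.argmax(result[0]))
```
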
diff --git a/fluid/image_classification/caffe2fluid/kaffe/layers.py b/fluid/image_classification/caffe2fluid/kaffe/layers.py index 6be35ed727fed76a1c96017455bdaa354ace9f97..675dab5ffd6098314124ceaf9884e54eb22fbb9f 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/layers.py +++ b/fluid/image_classification/caffe2fluid/kaffe/layers.py @@ -51,17 +51,77 @@ LAYER_DESCRIPTORS = { 'Threshold': shape_identity, } -LAYER_TYPES = LAYER_DESCRIPTORS.keys() +# layer types in 'V1LayerParameter' +# (v1layertype name, enum value, mapped to layer type) +v1_layertypes = [ + ('ABSVAL', 35), + ('ACCURACY', 1), + ('ARGMAX', 30), + ('BNLL', 2), + ('CONCAT', 3), + ('CONVOLUTION', 4), + ('DATA', 5), + ('DECONVOLUTION', 39), + ('DROPOUT', 6), + ('ELTWISE', 25), + ('EXP', 38), + ('FLATTEN', 8), + ('IM2COL', 11), + ('INNERPRODUCT', 14), + ('LRN', 15), + ('MEMORYDATA', 29), + ('MULTINOMIALLOGISTICLOSS', 16), + ('MVN', 34), + ('POOLING', 17), + ('POWER', 26), + ('RELU', 18), + ('SIGMOID', 19), + ('SIGMOIDCROSSENTROPYLOSS', 27), + ('SILENCE', 36), + ('SOFTMAX', 20), + ('SPLIT', 22), + ('SLICE', 33), + ('TANH', 23), + ('WINDOWDATA', 24), + ('THRESHOLD', 31), +] +LAYER_TYPES = LAYER_DESCRIPTORS.keys() LayerType = type('LayerType', (), {t: t for t in LAYER_TYPES}) +#map the layer name in V1 to standard name +V1_LAYER_MAP = {'_not_init_': True} +def get_v1_layer_map(): + global V1_LAYER_MAP + if '_not_init_' not in V1_LAYER_MAP: + return V1_LAYER_MAP + else: + del V1_LAYER_MAP['_not_init_'] + + name2layer = {} + for n in LAYER_TYPES: + name2layer[n.upper()] = n + + for l in v1_layertypes: + n, v = l + if n in name2layer and v not in V1_LAYER_MAP: + V1_LAYER_MAP[v] = name2layer[n] + else: + raise KaffeError('not found v1 layer type %s' % n) + return V1_LAYER_MAP + class NodeKind(LayerType): @staticmethod def map_raw_kind(kind): if kind in LAYER_TYPES: return kind - return None + + v1_layers = get_v1_layer_map() + if kind in v1_layers: + return v1_layers[kind] + else: + return None @staticmethod def compute_output_shape(node): diff --git a/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py b/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py index 620a84e8f1289672151f1f280559a56b37995ce0..866522a874adda4ff34f62b4a8b9b547dbe8d1d4 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py +++ b/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py @@ -27,6 +27,9 @@ def layer(op): self.layers[name] = layer_output # This output is now the input for the next layer. self.feed(layer_output) + #print('output shape of %s:' % (name)) + #print layer_output.shape + # Return self for chained calls. 
return self @@ -158,41 +161,64 @@ class Network(object): output = fluid.layers.relu(x=input) return output - @layer - def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None): - if padding is None: - padding = [0, 0] + def _adjust_pad_if_needed(self, i_hw, k_hw, s_hw, p_hw): + #adjust the padding if needed + i_h, i_w = i_hw + k_h, k_w = k_hw + s_h, s_w = s_hw + p_h, p_w = p_hw + + def is_consistent(i, k, s, p): + o = i + 2 * p - k + if o % s == 0: + return True + else: + return False + + real_p_h = 0 + real_p_w = 0 + if is_consistent(i_h, k_h, s_h, p_h) is False: + real_p_h = int(k_h / 2) + if is_consistent(i_w, k_w, s_w, p_w) is False: + real_p_w = int(k_w / 2) + + return [real_p_h, real_p_w] + + def pool(self, pool_type, input, k_h, k_w, s_h, s_w, name, padding): # Get the number of channels in the input - h_i, w_i = input.shape[2:] - fluid = import_fluid() - output = fluid.layers.pool2d( - input=input, - pool_size=[k_h, k_w], - pool_stride=[s_h, s_w], - pool_padding=padding, - pool_type='max') - return output + in_hw = input.shape[2:] + k_hw = [k_h, k_w] + s_hw = [s_h, s_w] - @layer - def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None): if padding is None: - padding = [0, 0] + #fix bug about the difference between conv and pool + #more info: https://github.com/BVLC/caffe/issues/1318 + padding = self._adjust_pad_if_needed(in_hw, k_hw, s_hw, [0, 0]) - # Get the number of channels in the input - h_i, w_i = input.shape[2:] fluid = import_fluid() output = fluid.layers.pool2d( input=input, - pool_size=[k_h, k_w], - pool_stride=[s_h, s_w], + pool_size=k_hw, + pool_stride=s_hw, pool_padding=padding, - pool_type='avg') + pool_type=pool_type) return output + @layer + def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None): + return self.pool('max', input, k_h, k_w, s_h, s_w, name, padding) + + @layer + def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=None): + return self.pool('avg', input, k_h, k_w, s_h, s_w, name, padding) + @layer def lrn(self, input, radius, alpha, beta, name, bias=1.0): - raise Exception('lrn() not implemented yet') + fluid = import_fluid() + output = fluid.layers.lrn(input=input, \ + n=radius, k=bias, alpha=alpha, beta=beta, name=name) + return output @layer def concat(self, inputs, axis, name): @@ -228,7 +254,7 @@ class Network(object): @layer def softmax(self, input, name): fluid = import_fluid() - output = fluid.layers.softmax(x=input, name=name) + output = fluid.layers.softmax(input) return output @layer @@ -256,5 +282,11 @@ class Network(object): return output @layer - def dropout(self, input, keep_prob, name): - raise Exception('dropout() not implemented yet') + def dropout(self, input, drop_prob, name, is_test=True): + fluid = import_fluid() + output = fluid.layers.dropout( + input, + dropout_prob=drop_prob, + is_test=is_test, + name=name) + return output diff --git a/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py b/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py index 92b9d32a3a755d8e6a2a8739cc3f42f9c8564b40..751112b61e125de99a80127db10164b75b91f7c8 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py +++ b/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py @@ -133,7 +133,7 @@ class TensorFlowMapper(NodeMapper): # We'll account for that here. 
alpha = params.alpha / float(params.local_size) return TensorFlowNode('lrn', - int(params.local_size / 2), alpha, params.beta) + params.local_size, alpha, params.beta) def map_concat(self, node): return TensorFlowNode('concat', node.parameters.axis) @@ -191,22 +191,33 @@ class TensorFlowEmitter(object): def emit_setup_def(self): return self.statement('def setup(self):') - def emit_convert_def(self, input_nodes): - def data_layer_def(name, shape, dtype=None): - if dtype is None: - dtype = 'float32' + def emit_shape_def(self, input_nodes): + self.outdent() + func_def = self.statement('@classmethod') + func_def += self.statement('def input_shapes(cls):') + self.indent() - layer_var = name + '_layer' - shape = [str(s) for s in shape[1:]] - layer_def = '%s = fluid.layers.data(name="%s", shape=[%s], dtype="%s")'\ - % (layer_var, name, ','.join(shape), dtype) - return layer_var, layer_def + input_shapes = {} + for n in input_nodes: + name = n.name + output_shape = n.output_shape + shape = [str(s) for s in output_shape[1:]] + input_shapes[name] = ', '.join(shape) + input_shapes = ['"%s": [%s]' % (n, l) for n, l in input_shapes.items()] + shape_str = ','.join(input_shapes) + func_def += self.statement('return {%s}' % (shape_str)) + return '\n\n' + func_def + def emit_convert_def(self, input_nodes): codes = [] inputs = {} + codes.append('shapes = cls.input_shapes()') for n in input_nodes: name = n.name - layer_var, layer_def = data_layer_def(n.name, n.output_shape) + layer_var = name + '_layer' + layer_def = '%s = fluid.layers.data(name="%s", shape=shapes["%s"],'\ + ' dtype="float32")' % (layer_var, name, name) + #layer_var, layer_def = data_layer_def(n.name, n.output_shape) codes.append(layer_def) inputs[name] = layer_var @@ -229,7 +240,7 @@ class TensorFlowEmitter(object): func_def += self.statement('import paddle.v2.fluid as fluid') for l in codes: func_def += self.statement(l) - return '\n\n' + func_def + return '\n' + func_def def emit_main_def(self, name): if name is None: @@ -273,6 +284,7 @@ class TensorFlowEmitter(object): b += self.emit_node(node) blocks.append(b[:-1]) s = s + '\n\n'.join(blocks) + s += self.emit_shape_def(input_nodes) s += self.emit_convert_def(input_nodes) s += self.emit_main_def(name) return s
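
Editor's note: one non-obvious change above is the pooling padding adjustment in network.py (_adjust_pad_if_needed). Caffe rounds pooling output sizes up (ceil) while Fluid's pool2d, like the convolution path, rounds down (floor), so a converted network can silently lose the last row/column of a feature map; the patch compensates by padding with k/2 when the strided window does not cover the input exactly (see the BVLC/caffe#1318 link in the code). A small numeric sketch of the effect, using the standard output-size formulas rather than code from this repository:

```python
import math

def caffe_pool_out(i, k, s, p=0):
    # Caffe pooling rounds up (ceil) when the window does not fit exactly.
    return int(math.ceil(float(i + 2 * p - k) / s)) + 1

def floor_pool_out(i, k, s, p=0):
    # Fluid's pool2d (and Caffe's convolution) round down (floor).
    return (i + 2 * p - k) // s + 1

# A ResNet-style 3x3, stride-2 max pool on a 112x112 feature map:
i, k, s = 112, 3, 2
print(caffe_pool_out(i, k, s))            # 56 -> size in the original Caffe model
print(floor_pool_out(i, k, s))            # 55 -> size after a naive conversion
print(floor_pool_out(i, k, s, p=k // 2))  # 56 -> size with the compensating pad of k/2
```
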
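Editor's note: the transformer.py changes above move input shapes out of hard-coded fluid.layers.data calls and into a generated input_shapes() classmethod, which the examples then query. The snippet below is only an illustrative sketch of what the emitted scaffolding roughly looks like for a hypothetical single-input network (the class name, helper method name, and the 3x224x224 shape are made up, not verbatim emitter output):

```python
import paddle.v2.fluid as fluid

class MyNet(object):
    # ... generated setup()/layer methods omitted ...

    @classmethod
    def input_shapes(cls):
        # Emitted by emit_shape_def: one entry per input node,
        # channel-first, without the batch dimension.
        return {"data": [3, 224, 224]}

    @classmethod
    def build_input_layers(cls):
        # Mirrors the emit_convert_def change: data layers are declared from
        # input_shapes() instead of hard-coded shapes. (Method name is illustrative.)
        shapes = cls.input_shapes()
        return {
            name: fluid.layers.data(name=name, shape=shape, dtype="float32")
            for name, shape in shapes.items()
        }

print(MyNet.input_shapes())
```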