From d300e5e437525ee11face24fc4c93207348a0fa8 Mon Sep 17 00:00:00 2001 From: walloollaw <37680514+walloollaw@users.noreply.github.com> Date: Fri, 13 Apr 2018 19:54:37 +0800 Subject: [PATCH] Update readme for caffe2fluid (#839) * fix code style problems * fix bug when loading fluid model * fix code style problem in brain.py * update readme.md for caffe2fluid * fix code style and add save_inference_model in caffe2fluid --- .../caffe2fluid/README.md | 31 ++-- .../caffe2fluid/examples/imagenet/README.md | 37 ++++- .../caffe2fluid/examples/imagenet/diff.sh | 0 .../caffe2fluid/examples/imagenet/infer.py | 143 ++++++++++++++---- .../caffe2fluid/examples/imagenet/run.sh | 2 +- .../caffe2fluid/examples/mnist/run.sh | 0 .../caffe2fluid/kaffe/paddle/transformer.py | 22 ++- .../caffe2fluid/proto/compile.sh | 0 8 files changed, 188 insertions(+), 47 deletions(-) mode change 100644 => 100755 fluid/image_classification/caffe2fluid/examples/imagenet/diff.sh mode change 100644 => 100755 fluid/image_classification/caffe2fluid/examples/imagenet/run.sh mode change 100644 => 100755 fluid/image_classification/caffe2fluid/examples/mnist/run.sh mode change 100644 => 100755 fluid/image_classification/caffe2fluid/proto/compile.sh diff --git a/fluid/image_classification/caffe2fluid/README.md b/fluid/image_classification/caffe2fluid/README.md index 6aba34b9..64f6b9cf 100644 --- a/fluid/image_classification/caffe2fluid/README.md +++ b/fluid/image_classification/caffe2fluid/README.md @@ -2,20 +2,31 @@ This tool is used to convert a Caffe model to Fluid model ### Howto -1, Prepare caffepb.py in ./proto if your python has no 'pycaffe' module, two options provided here: +1. Prepare caffepb.py in ./proto if your python has no 'pycaffe' module, two options provided here: +- Generate pycaffe from caffe.proto +
bash ./proto/compile.sh
- 1) generate it from caffe.proto using protoc - bash ./proto/compile.sh +- download one from github directly +
cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py
+
- 2) download one from github directly - cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py +2. Convert the Caffe model to Fluid model +- generate fluid code and weight file +
python convert.py alexnet.prototxt \
+        --caffemodel alexnet.caffemodel \
+        --data-output-path alexnet.npy \
+        --code-output-path alexnet.py
+
-2, Convert the caffe model using 'convert.py' which will generate a python script and a weight(in .npy) file +- save weights as fluid model file +
python alexnet.py alexnet.npy ./fluid_model
+
-3, Use the converted model to predict - - see more detail info in 'examples/xxx' +3. Use the converted model to infer +- see more details in '*examples/imagenet/run.sh*' +4. compare the inference results with caffe +- see more details in '*examples/imagenet/diff.sh*' ### Tested models - Lenet @@ -33,4 +44,4 @@ This tool is used to convert a Caffe model to Fluid model [model addr](https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet) ### Notes -Some of this code come from here: https://github.com/ethereon/caffe-tensorflow +Some of this code come from here: [caffe-tensorflow](https://github.com/ethereon/caffe-tensorflow) diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/README.md b/fluid/image_classification/caffe2fluid/examples/imagenet/README.md index b8205085..b9cf1941 100644 --- a/fluid/image_classification/caffe2fluid/examples/imagenet/README.md +++ b/fluid/image_classification/caffe2fluid/examples/imagenet/README.md @@ -1,10 +1,37 @@ -a demo to show converting caffe models on 'imagenet' using caffe2fluid +A demo to show converting caffe models on 'imagenet' using caffe2fluid --- # How to use -1. prepare python environment -2. download caffe model to "models.caffe/xxx" which contains "xxx.caffemodel" and "xxx.prototxt" -3. run the tool - eg: bash ./run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50 +1. Prepare python environment + +2. Download caffe model to "models.caffe/xxx" which contains "xxx.caffemodel" and "xxx.prototxt" + +3. Convert the Caffe model to Fluid model + - generate fluid code and weight file +
python convert.py alexnet.prototxt \
+        --caffemodel alexnet.caffemodel \
+        --data-output-path alexnet.npy \
+        --code-output-path alexnet.py
+    
+ + - save weights as fluid model file +
python alexnet.py alexnet.npy ./fluid_model
+    
+ +4. Do inference +
python infer.py infer ./fluid_mode data/65.jpeg
+
+ +5. convert model and do inference together +
bash ./run.sh alexnet ./models.caffe/alexnet ./models/alexnet
+
+ The Caffe model is stored in './models.caffe/alexnet/alexnet.prototxt|caffemodel' + and the Fluid model will be save in './models/alexnet/alexnet.py|npy' + +6. test the difference with caffe's results(need pycaffe installed) +
bash ./diff.sh resnet
+
+Make sure your caffemodel stored in './models.caffe/resnet'. +The results will be stored in './results/resnet.paddle|caffe' diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/diff.sh b/fluid/image_classification/caffe2fluid/examples/imagenet/diff.sh old mode 100644 new mode 100755 diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py b/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py index bb75caa9..099c0abb 100644 --- a/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py +++ b/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py @@ -59,12 +59,12 @@ def build_model(net_file, net_name): inputs_dict = MyNet.input_shapes() input_name = inputs_dict.keys()[0] input_shape = inputs_dict[input_name] - images = fluid.layers.data(name='image', shape=input_shape, dtype='float32') + images = fluid.layers.data( + name=input_name, shape=input_shape, dtype='float32') #label = fluid.layers.data(name='label', shape=[1], dtype='int64') net = MyNet({input_name: images}) - input_shape = MyNet.input_shapes()[input_name] - return net, input_shape + return net, inputs_dict def dump_results(results, names, root): @@ -78,26 +78,27 @@ def dump_results(results, names, root): np.save(filename + '.npy', res) -def infer(net_file, net_name, model_file, imgfile, debug=True): - """ do inference using a model which consist 'xxx.py' and 'xxx.npy' +def load_model(exe, place, net_file, net_name, net_weight, debug): + """ load model using xxxnet.py and xxxnet.npy """ - fluid = import_fluid() #1, build model - net, input_shape = build_model(net_file, net_name) + net, input_map = build_model(net_file, net_name) + feed_names = input_map.keys() + feed_shapes = [v for k, v in input_map.items()] + prediction = net.get_output() #2, load weights for this model - place = fluid.CPUPlace() - exe = fluid.Executor(place) startup_program = fluid.default_startup_program() exe.run(startup_program) - if model_file.find('.npy') > 0: - net.load(data_path=model_file, exe=exe, place=place) + #place = fluid.CPUPlace() + if net_weight.find('.npy') > 0: + net.load(data_path=net_weight, exe=exe, place=place) else: - net.load(data_path=model_file, exe=exe) + raise ValueError('not found weight file') #3, test this model test_program = fluid.default_main_program().clone() @@ -111,10 +112,75 @@ def infer(net_file, net_name, model_file, imgfile, debug=True): fetch_list_var.append(v) fetch_list_name.append(k) + return { + 'program': test_program, + 'feed_names': feed_names, + 'fetch_vars': fetch_list_var, + 'fetch_names': fetch_list_name, + 'feed_shapes': feed_shapes + } + + +def get_shape(fluid, program, name): + for var in program.list_vars(): + if var.name == 'data': + return list(var.shape[1:]) + + raise ValueError('not found shape for input layer[%s], ' + 'you can specify by yourself' % (name)) + + +def load_inference_model(dirname, exe): + """ load fluid's inference model + """ + fluid = import_fluid() + model_fn = 'model' + params_fn = 'params' + if os.path.exists(os.path.join(dirname, model_fn)) \ + and os.path.exists(os.path.join(dirname, params_fn)): + program, feed_names, fetch_targets = fluid.io.load_inference_model(\ + dirname, exe, model_fn, params_fn) + else: + raise ValueError('not found model files in direcotry[%s]' % (dirname)) + + #print fluid.global_scope().find_var(feed_names[0]) + input_shape = get_shape(fluid, program, feed_names[0]) + feed_shapes = [input_shape] + + return program, feed_names, fetch_targets, feed_shapes + + +def infer(model_path, imgfile, net_file=None, net_name=None, debug=True): + """ do inference using a model which consist 'xxx.py' and 'xxx.npy' + """ + + fluid = import_fluid() + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + try: + ret = load_inference_model(model_path, exe) + program, feed_names, fetch_targets, feed_shapes = ret + debug = False + print('found a inference model for fluid') + except ValueError as e: + pass + print('try to load model using net file and weight file') + net_weight = model_path + ret = load_model(exe, place, net_file, net_name, net_weight, debug) + program = ret['program'] + feed_names = ret['feed_names'] + fetch_targets = ret['fetch_vars'] + fetch_list_name = ret['fetch_names'] + feed_shapes = ret['feed_shapes'] + + input_name = feed_names[0] + input_shape = feed_shapes[0] + np_images = load_data(imgfile, input_shape) - results = exe.run(program=test_program, - feed={'image': np_images}, - fetch_list=fetch_list_var) + results = exe.run(program=program, + feed={input_name: np_images}, + fetch_list=fetch_targets) if debug is True: dump_path = 'results.paddle' @@ -122,7 +188,7 @@ def infer(net_file, net_name, model_file, imgfile, debug=True): print('all result of layers dumped to [%s]' % (dump_path)) else: result = results[0] - print('predicted class:', np.argmax(result)) + print('succeed infer with results[class:%d]' % (np.argmax(result))) return 0 @@ -167,9 +233,12 @@ if __name__ == "__main__": weight_file = 'models/resnet50/resnet50.npy' datafile = 'data/65.jpeg' net_name = 'ResNet50' + model_file = 'models/resnet50/fluid' - argc = len(sys.argv) - if sys.argv[1] == 'caffe': + ret = None + if len(sys.argv) <= 2: + pass + elif sys.argv[1] == 'caffe': if len(sys.argv) != 5: print('usage:') print('\tpython %s caffe [prototxt] [caffemodel] [datafile]' % @@ -178,18 +247,34 @@ if __name__ == "__main__": prototxt = sys.argv[2] caffemodel = sys.argv[3] datafile = sys.argv[4] - sys.exit(caffe_infer(prototxt, caffemodel, datafile)) - elif argc == 5: - net_file = sys.argv[1] - weight_file = sys.argv[2] + ret = caffe_infer(prototxt, caffemodel, datafile) + elif sys.argv[1] == 'infer': + if len(sys.argv) != 4: + print('usage:') + print('\tpython %s infer [fluid_model] [datafile]' % (sys.argv[0])) + sys.exit(1) + model_path = sys.argv[2] datafile = sys.argv[3] - net_name = sys.argv[4] - elif argc > 1: + ret = infer(model_path, datafile) + elif sys.argv[1] == 'dump': + if len(sys.argv) != 6: + print('usage:') + print('\tpython %s dump [net_file] [weight_file] [datafile] [net_name]' \ + % (sys.argv[0])) + print('\teg:python dump %s %s %s %s %s' % (sys.argv[0],\ + net_file, weight_file, datafile, net_name)) + sys.exit(1) + + net_file = sys.argv[2] + weight_file = sys.argv[3] + datafile = sys.argv[4] + net_name = sys.argv[5] + ret = infer(weight_file, datafile, net_file, net_name) + + if ret is None: print('usage:') - print('\tpython %s [net_file] [weight_file] [datafile] [net_name]' % - (sys.argv[0])) - print('\teg:python %s %s %s %s %s' % (sys.argv[0], net_file, - weight_file, datafile, net_name)) + print(' python %s [infer] [fluid_model] [imgfile]' % (sys.argv[0])) + print(' eg:python %s infer %s %s' % (sys.argv[0], model_file, datafile)) sys.exit(1) - infer(net_file, net_name, weight_file, datafile) + sys.exit(ret) diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/run.sh b/fluid/image_classification/caffe2fluid/examples/imagenet/run.sh old mode 100644 new mode 100755 index ff3cc4ac..2f0a0ba0 --- a/fluid/image_classification/caffe2fluid/examples/imagenet/run.sh +++ b/fluid/image_classification/caffe2fluid/examples/imagenet/run.sh @@ -71,7 +71,7 @@ if [[ -z $only_convert ]];then if [[ -z $net_name ]];then net_name="MyNet" fi - $PYTHON ./infer.py $net_file $weight_file $imgfile $net_name + $PYTHON ./infer.py dump $net_file $weight_file $imgfile $net_name ret=$? fi exit $ret diff --git a/fluid/image_classification/caffe2fluid/examples/mnist/run.sh b/fluid/image_classification/caffe2fluid/examples/mnist/run.sh old mode 100644 new mode 100755 diff --git a/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py b/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py index 36975299..20155e99 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py +++ b/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py @@ -216,7 +216,10 @@ class TensorFlowEmitter(object): def emit_convert_def(self, input_nodes): codes = [] inputs = {} + #codes.append('shapes = cls.input_shapes()') codes.append('shapes = cls.input_shapes()') + codes.append('input_name = shapes.keys()[0]') + codes.append('input_shape = shapes[input_name]') for n in input_nodes: name = n.name layer_var = name + '_layer' @@ -235,8 +238,14 @@ class TensorFlowEmitter(object): codes.append("exe = fluid.Executor(place)") codes.append("exe.run(fluid.default_startup_program())") codes.append("net.load(data_path=npy_model, exe=exe, place=place)") + codes.append("output_vars = [net.get_output()]") + codes.append("fluid.io.save_inference_model(" \ + "fluid_path, [input_name],output_vars," \ + "exe, main_program=None, model_filename='model'," \ + "params_filename='params')") codes.append( - "fluid.io.save_persistables(executor=exe, dirname=fluid_path)") + "print('save fluid model as [model] and [params] in directory [%s]' % (fluid_path))" + ) self.outdent() func_def = self.statement('@classmethod') @@ -254,8 +263,17 @@ class TensorFlowEmitter(object): self.prefix = '' main_def = self.statement('if __name__ == "__main__":') self.indent() - main_def += self.statement("#usage: python xxxnet.py xxx.npy ./model\n") + main_def += self.statement( + "#usage: save as an inference model for online service\n") main_def += self.statement("import sys") + main_def += self.statement("if len(sys.argv) != 3:") + self.indent() + main_def += self.statement("print('usage:')") + main_def += self.statement( + "print('\tpython %s [xxxnet.npy] [save_dir]' % (sys.argv[0]))") + main_def += self.statement("exit(1)") + + self.outdent() main_def += self.statement("npy_weight = sys.argv[1]") main_def += self.statement("fluid_model = sys.argv[2]") main_def += self.statement("%s.convert(npy_weight, fluid_model)" % diff --git a/fluid/image_classification/caffe2fluid/proto/compile.sh b/fluid/image_classification/caffe2fluid/proto/compile.sh old mode 100644 new mode 100755 -- GitLab