diff --git a/fluid/DeepASR/examples/aishell/prepare_data.sh b/fluid/DeepASR/examples/aishell/prepare_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d2c051c4d9ea10547f5ba4cc20213f430bf6dfce
--- /dev/null
+++ b/fluid/DeepASR/examples/aishell/prepare_data.sh
@@ -0,0 +1,37 @@
+data_dir=~/.cache/paddle/dataset/speech/deep_asr_data/aishell
+data_url='http://deep-asr-data.gz.bcebos.com/aishell_data.tar.gz'
+lst_url='http://deep-asr-data.gz.bcebos.com/aishell_lst.tar.gz'
+md5=e017d858d9e509c8a84b73f673f08b9a
+
+if [ ! -e $data_dir ]; then
+    mkdir -p $data_dir
+fi
+
+if [ ! -e $data_dir/aishell_data.tar.gz ]; then
+    echo "Download $data_dir/aishell_data.tar.gz ..."
+    wget -c -P $data_dir $data_url
+else
+    echo "Skip downloading: $data_dir/aishell_data.tar.gz already exists!"
+fi
+
+echo "Checking md5 sum ..."
+md5sum_tmp=`md5sum $data_dir/aishell_data.tar.gz | cut -d ' ' -f1`
+
+if [ "$md5sum_tmp" != "$md5" ]; then
+    echo "Md5sum check failed, please remove and redownload $data_dir/aishell_data.tar.gz"
+    exit 1
+fi
+
+echo "Untar aishell_data.tar.gz ..."
+tar xzf $data_dir/aishell_data.tar.gz -C $data_dir
+
+if [ ! -e data ]; then
+    mkdir data
+fi
+
+echo "Download and untar lst files ..."
+wget -c -P data $lst_url
+tar xvf data/aishell_lst.tar.gz -C data
+
+ln -s $data_dir data/aishell
diff --git a/fluid/DeepASR/examples/aishell/train.sh b/fluid/DeepASR/examples/aishell/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..41c0df2cd4985ae555f70554f27ff0dde8cb0cbe
--- /dev/null
+++ b/fluid/DeepASR/examples/aishell/train.sh
@@ -0,0 +1,13 @@
+export CUDA_VISIBLE_DEVICES=2,3,4,5
+python -u ../../train.py --train_feature_lst data/train_feature.lst \
+                         --train_label_lst data/train_label.lst \
+                         --val_feature_lst data/val_feature.lst \
+                         --val_label_lst data/val_label.lst \
+                         --mean_var data/aishell/global_mean_var \
+                         --checkpoints checkpoints \
+                         --frame_dim 2640 \
+                         --class_num 101 \
+                         --infer_models '' \
+                         --batch_size 128 \
+                         --learning_rate 0.00016 \
+                         --parallel
diff --git a/fluid/DeepASR/infer_by_ckpt.py b/fluid/DeepASR/infer_by_ckpt.py
index bf6093acb8e14ec926d1aefb759207905e468f8d..4a4073c02279bfd74b8ce31d0877a5338400d93b 100644
--- a/fluid/DeepASR/infer_by_ckpt.py
+++ b/fluid/DeepASR/infer_by_ckpt.py
@@ -17,6 +17,7 @@ from decoder.post_decode_faster import Decoder
 from data_utils.util import lodtensor_to_ndarray
 from model_utils.model import stacked_lstmp_model
 from data_utils.util import split_infer_result
+from tools.error_rate import char_errors


 def parse_args():
@@ -86,6 +87,11 @@ def parse_args():
         type=str,
         default='data/infer_label.lst',
         help='The label list path for inference. (default: %(default)s)')
+    parser.add_argument(
+        '--ref_txt',
+        type=str,
+        default='data/text.test',
+        help='The reference text for decoding. (default: %(default)s)')
     parser.add_argument(
         '--checkpoint',
         type=str,
@@ -111,6 +117,11 @@ def parse_args():
         type=float,
         default=0.2,
         help="Scaling factor for acoustic likelihoods. (default: %(default)f)")
+    parser.add_argument(
+        '--target_trans',
+        type=str,
+        default="./decoder/target_trans.txt",
+        help="The path to target transcription. (default: %(default)s)")
     args = parser.parse_args()
     return args
@@ -122,6 +133,18 @@ def print_arguments(args):
     print('------------------------------------------------')


+def get_trg_trans(args):
+    trans_dict = {}
+    with open(args.target_trans) as trg_trans:
+        line = trg_trans.readline()
+        while line:
+            items = line.strip().split()
+            key = items[0]
+            trans_dict[key] = ''.join(items[1:])
+            line = trg_trans.readline()
+    return trans_dict
+
+
 def infer_from_ckpt(args):
     """Inference by using checkpoint."""
@@ -145,6 +168,7 @@ def infer_from_ckpt(args):
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())

+    trg_trans = get_trg_trans(args)
     # load checkpoint.
     fluid.io.load_persistables(exe, args.checkpoint)
@@ -166,11 +190,12 @@ def infer_from_ckpt(args):
                                       args.infer_label_lst)
     infer_data_reader.set_transformers(ltrans)
     infer_costs, infer_accs = [], []
+    total_edit_dist, total_ref_len = 0.0, 0
     for batch_id, batch_data in enumerate(
             infer_data_reader.batch_iterator(args.batch_size,
                                              args.minimum_batch_size)):
         # load_data
-        (features, labels, lod) = batch_data
+        (features, labels, lod, name_lst) = batch_data
         feature_t.set(features, place)
         feature_t.set_lod([lod])
         label_t.set(labels, place)
@@ -186,11 +211,19 @@ def infer_from_ckpt(args):
         probs, lod = lodtensor_to_ndarray(results[0])
         infer_batch = split_infer_result(probs, lod)
-        for index, sample in enumerate(infer_batch):
-            key = "utter#%d" % (batch_id * args.batch_size + index)
-            print(key, ": ", decoder.decode(key, sample).encode("utf8"), "\n")
-    print(np.mean(infer_costs), np.mean(infer_accs))
+        for index, sample in enumerate(infer_batch):
+            key = name_lst[index]
+            ref = trg_trans[key]
+            hyp = decoder.decode(key, sample)
+            edit_dist, ref_len = char_errors(ref.decode("utf8"), hyp)
+            total_edit_dist += edit_dist
+            total_ref_len += ref_len
+            print(key + "|Ref:", ref)
+            print(key + "|Hyp:", hyp.encode("utf8"))
+            print("Instance CER: ", edit_dist / ref_len)
+
+    print("Total CER = %f" % (total_edit_dist / total_ref_len))


 if __name__ == '__main__':
diff --git a/fluid/DeepASR/tools/error_rate.py b/fluid/DeepASR/tools/error_rate.py
new file mode 100644
index 0000000000000000000000000000000000000000..215ad39d24a551879d0fd8d4c8892161a0708370
--- /dev/null
+++ b/fluid/DeepASR/tools/error_rate.py
@@ -0,0 +1,182 @@
+# -*- coding: utf-8 -*-
+"""This module provides functions to calculate error rates at different levels,
+e.g. WER at word level and CER at char level.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+def _levenshtein_distance(ref, hyp):
+    """Levenshtein distance is a string metric for measuring the difference
+    between two sequences. Informally, the Levenshtein distance is defined as
+    the minimum number of single-character edits (substitutions, insertions or
+    deletions) required to change one word into the other. We can naturally
+    extend the edits to word level when calculating the Levenshtein distance of
+    two sentences.
+ """ + m = len(ref) + n = len(hyp) + + # special case + if ref == hyp: + return 0 + if m == 0: + return n + if n == 0: + return m + + if m < n: + ref, hyp = hyp, ref + m, n = n, m + + # use O(min(m, n)) space + distance = np.zeros((2, n + 1), dtype=np.int32) + + # initialize distance matrix + for j in xrange(n + 1): + distance[0][j] = j + + # calculate levenshtein distance + for i in xrange(1, m + 1): + prev_row_idx = (i - 1) % 2 + cur_row_idx = i % 2 + distance[cur_row_idx][0] = i + for j in xrange(1, n + 1): + if ref[i - 1] == hyp[j - 1]: + distance[cur_row_idx][j] = distance[prev_row_idx][j - 1] + else: + s_num = distance[prev_row_idx][j - 1] + 1 + i_num = distance[cur_row_idx][j - 1] + 1 + d_num = distance[prev_row_idx][j] + 1 + distance[cur_row_idx][j] = min(s_num, i_num, d_num) + + return distance[m % 2][n] + + +def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '): + """Compute the levenshtein distance between reference sequence and + hypothesis sequence in word-level. + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. + :type ignore_case: bool + :param delimiter: Delimiter of input sentences. + :type delimiter: char + :return: Levenshtein distance and word number of reference sentence. + :rtype: list + """ + if ignore_case == True: + reference = reference.lower() + hypothesis = hypothesis.lower() + + ref_words = filter(None, reference.split(delimiter)) + hyp_words = filter(None, hypothesis.split(delimiter)) + + edit_distance = _levenshtein_distance(ref_words, hyp_words) + return float(edit_distance), len(ref_words) + + +def char_errors(reference, hypothesis, ignore_case=False, remove_space=False): + """Compute the levenshtein distance between reference sequence and + hypothesis sequence in char-level. + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. + :type ignore_case: bool + :param remove_space: Whether remove internal space characters + :type remove_space: bool + :return: Levenshtein distance and length of reference sentence. + :rtype: list + """ + if ignore_case == True: + reference = reference.lower() + hypothesis = hypothesis.lower() + + join_char = ' ' + if remove_space == True: + join_char = '' + + reference = join_char.join(filter(None, reference.split(' '))) + hypothesis = join_char.join(filter(None, hypothesis.split(' '))) + + edit_distance = _levenshtein_distance(reference, hypothesis) + return float(edit_distance), len(reference) + + +def wer(reference, hypothesis, ignore_case=False, delimiter=' '): + """Calculate word error rate (WER). WER compares reference text and + hypothesis text in word-level. WER is defined as: + .. math:: + WER = (Sw + Dw + Iw) / Nw + where + .. code-block:: text + Sw is the number of words subsituted, + Dw is the number of words deleted, + Iw is the number of words inserted, + Nw is the number of words in the reference + We can use levenshtein distance to calculate WER. Please draw an attention + that empty items will be removed when splitting sentences by delimiter. + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. 
+ :type ignore_case: bool + :param delimiter: Delimiter of input sentences. + :type delimiter: char + :return: Word error rate. + :rtype: float + :raises ValueError: If word number of reference is zero. + """ + edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, + delimiter) + + if ref_len == 0: + raise ValueError("Reference's word number should be greater than 0.") + + wer = float(edit_distance) / ref_len + return wer + + +def cer(reference, hypothesis, ignore_case=False, remove_space=False): + """Calculate charactor error rate (CER). CER compares reference text and + hypothesis text in char-level. CER is defined as: + .. math:: + CER = (Sc + Dc + Ic) / Nc + where + .. code-block:: text + Sc is the number of characters substituted, + Dc is the number of characters deleted, + Ic is the number of characters inserted + Nc is the number of characters in the reference + We can use levenshtein distance to calculate CER. Chinese input should be + encoded to unicode. Please draw an attention that the leading and tailing + space characters will be truncated and multiple consecutive space + characters in a sentence will be replaced by one space character. + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. + :type ignore_case: bool + :param remove_space: Whether remove internal space characters + :type remove_space: bool + :return: Character error rate. + :rtype: float + :raises ValueError: If the reference length is zero. + """ + edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, + remove_space) + + if ref_len == 0: + raise ValueError("Length of reference should be greater than 0.") + + cer = float(edit_distance) / ref_len + return cer diff --git a/fluid/image_classification/caffe2fluid/README.md b/fluid/image_classification/caffe2fluid/README.md index 6aba34b9cafbd87b3474575fcbcee65819769c2f..9a6daad90222ab036cac896a66e50f273deac3d7 100644 --- a/fluid/image_classification/caffe2fluid/README.md +++ b/fluid/image_classification/caffe2fluid/README.md @@ -1,24 +1,64 @@ ### Caffe2Fluid -This tool is used to convert a Caffe model to Fluid model +This tool is used to convert a Caffe model to a Fluid model -### Howto -1, Prepare caffepb.py in ./proto if your python has no 'pycaffe' module, two options provided here: - - 1) generate it from caffe.proto using protoc +### HowTo +1. Prepare caffepb.py in ./proto if your python has no 'pycaffe' module, two options provided here: + - Generate pycaffe from caffe.proto + ``` bash ./proto/compile.sh + ``` - 2) download one from github directly + - Download one from github directly + ``` cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py + ``` + +2. Convert the Caffe model to Fluid model + - Generate fluid code and weight file + ``` + python convert.py alexnet.prototxt \ + --caffemodel alexnet.caffemodel \ + --data-output-path alexnet.npy \ + --code-output-path alexnet.py + ``` + + - Save weights as fluid model file + ``` + python alexnet.py alexnet.npy ./fluid #only infer the last layer's result + python alexnet.py alexnet.npy ./fluid fc8,prob #infer these 2 layer's result + ``` + +3. Use the converted model to infer + - See more details in '*examples/imagenet/run.sh*' + +4. Compare the inference results with caffe + - See more details in '*examples/imagenet/diff.sh*' + +### How to convert custom layer +1. 
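
A quick sanity check of the two metrics above (a minimal sketch; the toy strings are illustrative, not taken from any dataset):

    from tools.error_rate import cer, wer

    # one word substituted ("cat" -> "bat") out of three reference words
    print(wer('the cat sat', 'the bat sat'))  # ~0.3333

    # one character substituted out of four reference characters
    print(cer('abcd', 'abed'))  # 0.25
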
diff --git a/fluid/image_classification/caffe2fluid/README.md b/fluid/image_classification/caffe2fluid/README.md
index 6aba34b9cafbd87b3474575fcbcee65819769c2f..9a6daad90222ab036cac896a66e50f273deac3d7 100644
--- a/fluid/image_classification/caffe2fluid/README.md
+++ b/fluid/image_classification/caffe2fluid/README.md
@@ -1,24 +1,64 @@
 ### Caffe2Fluid
-This tool is used to convert a Caffe model to Fluid model
+This tool is used to convert a Caffe model to a Fluid model

-### Howto
-1, Prepare caffepb.py in ./proto if your python has no 'pycaffe' module, two options provided here:
-
-    1) generate it from caffe.proto using protoc
+### HowTo
+1. Prepare caffepb.py in ./proto if your python has no 'pycaffe' module; two options are provided here:
+    - Generate pycaffe from caffe.proto
+    ```
+    bash ./proto/compile.sh
+    ```
+
-    2) download one from github directly
+    - Download one from github directly
+    ```
+    cd proto/ && wget https://github.com/ethereon/caffe-tensorflow/blob/master/kaffe/caffe/caffepb.py
+    ```
+
+2. Convert the Caffe model to a Fluid model
+    - Generate fluid code and weight file
+    ```
+    python convert.py alexnet.prototxt \
+            --caffemodel alexnet.caffemodel \
+            --data-output-path alexnet.npy \
+            --code-output-path alexnet.py
+    ```
+
+    - Save weights as fluid model file
+    ```
+    python alexnet.py alexnet.npy ./fluid  #only infer the last layer's result
+    python alexnet.py alexnet.npy ./fluid fc8,prob  #infer these 2 layers' results
+    ```
+
+3. Use the converted model to infer
+    - See more details in '*examples/imagenet/run.sh*'
+
+4. Compare the inference results with caffe
+    - See more details in '*examples/imagenet/diff.sh*'
+
+### How to convert a custom layer
+1. Implement your custom layer in a file under '*kaffe/custom_layers*', e.g. mylayer.py
+    - Implement ```shape_func(input_shape, [other_caffe_params])``` to calculate the output shape
+    - Implement ```layer_func(inputs, name, [other_caffe_params])``` to construct a fluid layer
+    - Register these two functions: ```register(kind='MyType', shape=shape_func, layer=layer_func)```
+    - Notes: more examples can be found in '*kaffe/custom_layers*'
+
+2. Add ```import mylayer``` to '*kaffe/custom_layers/\_\_init__.py*'

-2, Convert the caffe model using 'convert.py' which will generate a python script and a weight(in .npy) file
+3. Prepare your pycaffe as the customized version (same as the env preparation above)
+    - (option1) replace 'proto/caffe.proto' with your own caffe.proto and compile it
+    - (option2) change your pycaffe to the customized version

-3, Use the converted model to predict
+4. Convert the Caffe model to a Fluid model

-    see more detail info in 'examples/xxx'
+5. Set env $CAFFE2FLUID_CUSTOM_LAYERS to the parent directory of 'custom_layers'
+    ```
+    export CAFFE2FLUID_CUSTOM_LAYERS=/path/to/caffe2fluid/kaffe
+    ```
+
+6. Use the converted model when loading the model in 'xxxnet.py' and 'xxxnet.npy' (not needed if the model is already saved as 'fluid/model' and 'fluid/params')

 ### Tested models
-- Lenet
+- Lenet:
+[model addr](https://github.com/ethereon/caffe-tensorflow/blob/master/examples/mnist)

 - ResNets:(ResNet-50, ResNet-101, ResNet-152)
 [model addr](https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)

@@ -33,4 +73,4 @@ This tool is used to convert a Caffe model to Fluid model
 [model addr](https://github.com/BVLC/caffe/tree/master/models/bvlc_alexnet)

 ### Notes
-Some of this code come from here: https://github.com/ethereon/caffe-tensorflow
+Some of this code comes from: [caffe-tensorflow](https://github.com/ethereon/caffe-tensorflow)
diff --git a/fluid/image_classification/caffe2fluid/convert.py b/fluid/image_classification/caffe2fluid/convert.py
index 379f1a26368c9ffa4a9f82dad499ad7114f942fc..b0252e3c03db3626696a3672971f0704461417e7 100755
--- a/fluid/image_classification/caffe2fluid/convert.py
+++ b/fluid/image_classification/caffe2fluid/convert.py
@@ -43,11 +43,17 @@ def convert(def_path, caffemodel_path, data_output_path, code_output_path,
             print_stderr('Saving source...')
             with open(code_output_path, 'wb') as src_out:
                 src_out.write(transformer.transform_source())
+        print_stderr('set env variable before using the converted model '\
+                     'if custom_layers are used:')
+        custom_pk_path = os.path.dirname(os.path.abspath(__file__))
+        custom_pk_path = os.path.join(custom_pk_path, 'kaffe')
+        print_stderr('export CAFFE2FLUID_CUSTOM_LAYERS=%s' % (custom_pk_path))
         print_stderr('Done.')
+        return 0
     except KaffeError as err:
         fatal_error('Error encountered: {}'.format(err))
-    return 0
+    return 1


 def main():
diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/README.md b/fluid/image_classification/caffe2fluid/examples/imagenet/README.md
index b82050859239be8804ddec8e2054edc38c4ac052..b9cf1941d29428c84c34df2a9ec00d7ae8e79014 100644
--- a/fluid/image_classification/caffe2fluid/examples/imagenet/README.md
+++ b/fluid/image_classification/caffe2fluid/examples/imagenet/README.md
@@ -1,10 +1,37 @@
-a demo to show converting caffe models on 'imagenet' using caffe2fluid
+A demo to show converting caffe models on 'imagenet' using caffe2fluid

 ---

 # How to use

-1. prepare python environment
-2. download caffe model to "models.caffe/xxx" which contains "xxx.caffemodel" and "xxx.prototxt"
-3. run the tool
-    eg: bash ./run.sh resnet50 ./models.caffe/resnet50 ./models/resnet50
+1. Prepare python environment
+
+2. Download the caffe model to "models.caffe/xxx", which contains "xxx.caffemodel" and "xxx.prototxt"
+
+3. Convert the Caffe model to a Fluid model
+    - generate fluid code and weight file
+        python convert.py alexnet.prototxt \
+ --caffemodel alexnet.caffemodel \
+ --data-output-path alexnet.npy \
+ --code-output-path alexnet.py
+
+
+ - save weights as fluid model file
+ python alexnet.py alexnet.npy ./fluid_model
+
+
+4. Do inference
+ python infer.py infer ./fluid_model data/65.jpeg
+
+
+5. Convert the model and do inference together
+ bash ./run.sh alexnet ./models.caffe/alexnet ./models/alexnet
+
+ The Caffe model is stored in './models.caffe/alexnet/alexnet.prototxt|caffemodel'
+ and the Fluid model will be saved in './models/alexnet/alexnet.py|npy'
+
+6. Test the difference with caffe's results (requires pycaffe to be installed)
+ bash ./diff.sh resnet
+
+Make sure your caffemodel is stored in './models.caffe/resnet'.
+The results will be stored in './results/resnet.paddle|caffe'.
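+
+To get a rough sense of the numeric gap between the two dumps, a minimal
+sketch (it assumes each dump contains a final 'prob' layer saved as .npy;
+adjust the layer name to your net):
+
+    import numpy as np
+    paddle_out = np.load('./results/resnet.paddle/prob.npy')
+    caffe_out = np.load('./results/resnet.caffe/prob.npy')
+    print('max abs diff:', np.abs(paddle_out - caffe_out).max())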
diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/diff.sh b/fluid/image_classification/caffe2fluid/examples/imagenet/diff.sh
old mode 100644
new mode 100755
diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py b/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py
index bb75caa9e7364465042c5c88f471e8f6f5137237..d71a91ad7e731e4585ae4adfb44b0a1019260e0d 100644
--- a/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py
+++ b/fluid/image_classification/caffe2fluid/examples/imagenet/infer.py
@@ -43,7 +43,7 @@ def build_model(net_file, net_name):
(net_file, net_name))
net_path = os.path.dirname(net_file)
- module_name = os.path.basename(net_file).rstrip('.py')
+ module_name = os.path.splitext(os.path.basename(net_file))[0]
if net_path not in sys.path:
sys.path.insert(0, net_path)
@@ -51,7 +51,7 @@ def build_model(net_file, net_name):
m = __import__(module_name, fromlist=[net_name])
MyNet = getattr(m, net_name)
except Exception as e:
- print('failed to load module[%s]' % (module_name))
+ print('failed to load module[%s.%s]' % (module_name, net_name))
print(e)
return None
@@ -59,12 +59,12 @@ def build_model(net_file, net_name):
inputs_dict = MyNet.input_shapes()
input_name = inputs_dict.keys()[0]
input_shape = inputs_dict[input_name]
- images = fluid.layers.data(name='image', shape=input_shape, dtype='float32')
+ images = fluid.layers.data(
+ name=input_name, shape=input_shape, dtype='float32')
#label = fluid.layers.data(name='label', shape=[1], dtype='int64')
net = MyNet({input_name: images})
- input_shape = MyNet.input_shapes()[input_name]
- return net, input_shape
+ return net, inputs_dict
def dump_results(results, names, root):
@@ -78,26 +78,27 @@ def dump_results(results, names, root):
np.save(filename + '.npy', res)
-def infer(net_file, net_name, model_file, imgfile, debug=True):
- """ do inference using a model which consist 'xxx.py' and 'xxx.npy'
+def load_model(exe, place, net_file, net_name, net_weight, debug):
+ """ load model using xxxnet.py and xxxnet.npy
"""
-
fluid = import_fluid()
#1, build model
- net, input_shape = build_model(net_file, net_name)
+ net, input_map = build_model(net_file, net_name)
+ feed_names = input_map.keys()
+ feed_shapes = [v for k, v in input_map.items()]
+
prediction = net.get_output()
#2, load weights for this model
- place = fluid.CPUPlace()
- exe = fluid.Executor(place)
startup_program = fluid.default_startup_program()
exe.run(startup_program)
- if model_file.find('.npy') > 0:
- net.load(data_path=model_file, exe=exe, place=place)
+ #place = fluid.CPUPlace()
+ if net_weight.find('.npy') > 0:
+ net.load(data_path=net_weight, exe=exe, place=place)
else:
- net.load(data_path=model_file, exe=exe)
+ raise ValueError('weight file not found')
#3, test this model
test_program = fluid.default_main_program().clone()
@@ -111,10 +112,73 @@ def infer(net_file, net_name, model_file, imgfile, debug=True):
fetch_list_var.append(v)
fetch_list_name.append(k)
+ return {
+ 'program': test_program,
+ 'feed_names': feed_names,
+ 'fetch_vars': fetch_list_var,
+ 'fetch_names': fetch_list_name,
+ 'feed_shapes': feed_shapes
+ }
+
+
+def get_shape(fluid, program, name):
+    for var in program.list_vars():
+        # match the feed variable by its real name instead of a hard-coded 'data'
+        if var.name == name:
+            return list(var.shape[1:])
+
+    raise ValueError('shape not found for input layer[%s], '
+                     'you can specify it yourself' % (name))
+
+
+def load_inference_model(dirname, exe):
+ """ load fluid's inference model
+ """
+ fluid = import_fluid()
+ model_fn = 'model'
+ params_fn = 'params'
+ if os.path.exists(os.path.join(dirname, model_fn)) \
+ and os.path.exists(os.path.join(dirname, params_fn)):
+ program, feed_names, fetch_targets = fluid.io.load_inference_model(\
+ dirname, exe, model_fn, params_fn)
+ else:
+ raise ValueError('model files not found in directory[%s]' % (dirname))
+
+ #print fluid.global_scope().find_var(feed_names[0])
+ input_shape = get_shape(fluid, program, feed_names[0])
+ feed_shapes = [input_shape]
+
+ return program, feed_names, fetch_targets, feed_shapes
+
+
+def infer(model_path, imgfile, net_file=None, net_name=None, debug=True):
+ """ do inference using a model which consist 'xxx.py' and 'xxx.npy'
+ """
+ fluid = import_fluid()
+
+ place = fluid.CPUPlace()
+ exe = fluid.Executor(place)
+ try:
+ ret = load_inference_model(model_path, exe)
+ program, feed_names, fetch_targets, feed_shapes = ret
+ debug = False
+ print('found an inference model for fluid')
+ except ValueError as e:
+ print('trying to load the model using net file and weight file')
+ net_weight = model_path
+ ret = load_model(exe, place, net_file, net_name, net_weight, debug)
+ program = ret['program']
+ feed_names = ret['feed_names']
+ fetch_targets = ret['fetch_vars']
+ fetch_list_name = ret['fetch_names']
+ feed_shapes = ret['feed_shapes']
+
+ input_name = feed_names[0]
+ input_shape = feed_shapes[0]
+
np_images = load_data(imgfile, input_shape)
- results = exe.run(program=test_program,
- feed={'image': np_images},
- fetch_list=fetch_list_var)
+ results = exe.run(program=program,
+ feed={input_name: np_images},
+ fetch_list=fetch_targets)
if debug is True:
dump_path = 'results.paddle'
@@ -122,7 +186,7 @@ def infer(net_file, net_name, model_file, imgfile, debug=True):
print('all result of layers dumped to [%s]' % (dump_path))
else:
result = results[0]
- print('predicted class:', np.argmax(result))
+ print('succeeded in inference, result[class:%d]' % (np.argmax(result)))
return 0
@@ -167,9 +231,12 @@ if __name__ == "__main__":
weight_file = 'models/resnet50/resnet50.npy'
datafile = 'data/65.jpeg'
net_name = 'ResNet50'
+ model_file = 'models/resnet50/fluid'
- argc = len(sys.argv)
- if sys.argv[1] == 'caffe':
+ ret = None
+ if len(sys.argv) <= 2:
+ pass
+ elif sys.argv[1] == 'caffe':
if len(sys.argv) != 5:
print('usage:')
print('\tpython %s caffe [prototxt] [caffemodel] [datafile]' %
@@ -178,18 +245,34 @@ if __name__ == "__main__":
prototxt = sys.argv[2]
caffemodel = sys.argv[3]
datafile = sys.argv[4]
- sys.exit(caffe_infer(prototxt, caffemodel, datafile))
- elif argc == 5:
- net_file = sys.argv[1]
- weight_file = sys.argv[2]
+ ret = caffe_infer(prototxt, caffemodel, datafile)
+ elif sys.argv[1] == 'infer':
+ if len(sys.argv) != 4:
+ print('usage:')
+ print('\tpython %s infer [fluid_model] [datafile]' % (sys.argv[0]))
+ sys.exit(1)
+ model_path = sys.argv[2]
datafile = sys.argv[3]
- net_name = sys.argv[4]
- elif argc > 1:
+ ret = infer(model_path, datafile)
+ elif sys.argv[1] == 'dump':
+ if len(sys.argv) != 6:
+ print('usage:')
+ print('\tpython %s dump [net_file] [weight_file] [datafile] [net_name]' \
+ % (sys.argv[0]))
+            print('\teg:python %s dump %s %s %s %s' % (sys.argv[0],\
+                  net_file, weight_file, datafile, net_name))
+ sys.exit(1)
+
+ net_file = sys.argv[2]
+ weight_file = sys.argv[3]
+ datafile = sys.argv[4]
+ net_name = sys.argv[5]
+ ret = infer(weight_file, datafile, net_file, net_name)
+
+ if ret is None:
print('usage:')
- print('\tpython %s [net_file] [weight_file] [datafile] [net_name]' %
- (sys.argv[0]))
- print('\teg:python %s %s %s %s %s' % (sys.argv[0], net_file,
- weight_file, datafile, net_name))
+ print(' python %s infer [fluid_model] [imgfile]' % (sys.argv[0]))
+ print(' eg:python %s infer %s %s' % (sys.argv[0], model_file, datafile))
sys.exit(1)
- infer(net_file, net_name, weight_file, datafile)
+ sys.exit(ret)
diff --git a/fluid/image_classification/caffe2fluid/examples/imagenet/run.sh b/fluid/image_classification/caffe2fluid/examples/imagenet/run.sh
old mode 100644
new mode 100755
index ff3cc4ac44a8ccaeb0b33f1bcdbc46886fb7d7e9..0fdd56e4519bf726a8e5bc95559d1d9b47f14774
--- a/fluid/image_classification/caffe2fluid/examples/imagenet/run.sh
+++ b/fluid/image_classification/caffe2fluid/examples/imagenet/run.sh
@@ -67,11 +67,11 @@ if [[ -z $only_convert ]];then
imgfile="data/65.jpeg"
#FIX ME:
# only look the first line in prototxt file for the name of this network, maybe not correct
- net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/^\s*name\s*:\s*\"([^\"]+)\"/){ print $1."\n";}'`
+ net_name=`grep "name" $proto_file | head -n1 | perl -ne 'if(/^name\s*:\s*\"([^\"]+)\"/){ print $1."\n";}'`
if [[ -z $net_name ]];then
net_name="MyNet"
fi
- $PYTHON ./infer.py $net_file $weight_file $imgfile $net_name
+ $PYTHON ./infer.py dump $net_file $weight_file $imgfile $net_name
ret=$?
fi
exit $ret
diff --git a/fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py b/fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py
index 5c86635d5a014262bdec40fe063915350c5fadb3..946fa943726b39c4e8e8dfce9f41c87a06ee1912 100644
--- a/fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py
+++ b/fluid/image_classification/caffe2fluid/examples/mnist/evaluate.py
@@ -7,8 +7,8 @@
import sys
import os
import numpy as np
+import paddle.fluid as fluid
import paddle.v2 as paddle
-import paddle.v2.fluid as fluid
def test_model(exe, test_program, fetch_list, test_reader, feeder):
@@ -34,9 +34,6 @@ def evaluate(net_file, model_file):
from lenet import LeNet as MyNet
- with_gpu = False
- paddle.init(use_gpu=with_gpu)
-
#1, define network topology
images = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
@@ -45,7 +42,7 @@ def evaluate(net_file, model_file):
prediction = net.layers['prob']
acc = fluid.layers.accuracy(input=prediction, label=label)
- place = fluid.CUDAPlace(0) if with_gpu is True else fluid.CPUPlace()
+ place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
diff --git a/fluid/image_classification/caffe2fluid/examples/mnist/run.sh b/fluid/image_classification/caffe2fluid/examples/mnist/run.sh
old mode 100644
new mode 100755
diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2276c09c2c408f4c6e65264b4bde91429df53ca
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py
@@ -0,0 +1,105 @@
+"""
+"""
+
+from .register import get_registered_layers
+#custom layer import begins
+
+import axpy
+import flatten
+import argmax
+
+#custom layer import ends
+
+custom_layers = get_registered_layers()
+
+
+def set_args(f, params):
+ """ set args for function 'f' using the parameters in node.layer.parameters
+
+ Args:
+ f (function): a python function object
+ params (object): an object that contains the attributes needed by f's arguments
+
+ Returns:
+ arg_names (list): a list of argument names
+ kwargs (dict): a dict that contains the needed arguments
+ """
+ argc = f.__code__.co_argcount
+ arg_list = f.__code__.co_varnames[0:argc]
+
+ kwargs = {}
+ for arg_name in arg_list:
+ try:
+ v = getattr(params, arg_name, None)
+ except Exception as e:
+ #maybe failed to extract caffe's parameters
+ v = None
+
+ if v is not None:
+ kwargs[arg_name] = v
+
+ return arg_list, kwargs
+
+
+def has_layer(kind):
+ """ test whether this layer exists in custom layer
+ """
+ return kind in custom_layers
+
+
+def compute_output_shape(kind, node):
+ assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
+ kind)
+ shape_func = custom_layers[kind]['shape']
+
+ parents = node.parents
+ inputs = [list(p.output_shape) for p in parents]
+ arg_names, kwargs = set_args(shape_func, node.layer.parameters)
+
+ if len(inputs) == 1:
+ inputs = inputs[0]
+
+ return shape_func(inputs, **kwargs)
+
+
+def make_node(template, kind, node):
+ """ make a TensorFlowNode for custom layer which means construct
+ a piece of code to define a layer implemented in 'custom_layers'
+
+ Args:
+ @template (TensorFlowNode): a factory to create an instance of TensorFlowNode
+ @kind (str): type of custom layer
+ @node (graph.Node): a layer in the net
+
+ Returns:
+ instance of TensorFlowNode
+ """
+ assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
+ kind)
+
+ layer_func = custom_layers[kind]['layer']
+
+ #construct arguments needed by custom layer function from node's parameters
+ arg_names, kwargs = set_args(layer_func, node.layer.parameters)
+
+ return template('custom_layer', kind, **kwargs)
+
+
+def make_custom_layer(kind, inputs, name, *args, **kwargs):
+ """ execute a custom layer which is implemented by users
+
+ Args:
+ @kind (str): type name of this layer
+ @inputs (vars): variable list created by fluid
+ @name (str): name for this layer
+ @args (tuple): other positional arguments
+ @kwargs (dict): other keyword arguments
+
+ Returns:
+ output (var): output variable for this layer
+ """
+ assert kind in custom_layers, "layer[%s] not exist in custom layers" % (
+ kind)
+
+ layer_func = custom_layers[kind]['layer']
+ return layer_func(inputs, name, *args, **kwargs)
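
For intuition, `set_args` simply intersects a function's argument names with the attributes present on the caffe parameter object. A standalone sketch (the `FakeParams` class is hypothetical, standing in for node.layer.parameters):

    def argmax_shape(input_shape, out_max_val=False, top_k=1, axis=-1):
        pass  # signature only, for demonstration

    class FakeParams(object):
        out_max_val = True
        top_k = 5

    argc = argmax_shape.__code__.co_argcount
    arg_names = argmax_shape.__code__.co_varnames[0:argc]
    kwargs = {}
    for arg_name in arg_names:
        v = getattr(FakeParams, arg_name, None)
        if v is not None:
            kwargs[arg_name] = v
    print(kwargs)  # only the attributes that exist: out_max_val and top_k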
diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/argmax.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/argmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d884f53a1027d091fe409632209a2d9a579f573
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/argmax.py
@@ -0,0 +1,71 @@
+""" a custom layer for 'argmax', maybe we should implement this in standard way.
+ more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/argmax.html
+"""
+from .register import register
+
+
+def import_fluid():
+ import paddle.fluid as fluid
+ return fluid
+
+
+def argmax_shape(input_shape, out_max_val=False, top_k=1, axis=-1):
+ """ calculate the output shape of this layer using input shape
+
+ Args:
+ @input_shape (list of num): a list of numbers which represents the input shape
+ @out_max_val (bool): parameter from caffe's ArgMax layer
+ @top_k (int): parameter from caffe's ArgMax layer
+ @axis (int): parameter from caffe's ArgMax layer
+
+ Returns:
+ @output_shape (list of num): a list of numbers that represent the output shape
+ """
+ input_shape = list(input_shape)
+
+ if axis < 0:
+ axis += len(input_shape)
+
+ assert (axis + 1 == len(input_shape)
+ ), 'can only be applied to the last dimension[axis:%d, %s] now,'\
+ 'make sure you have set the axis param in the xxx.prototxt file' \
+ % (axis, str(input_shape))
+
+ output_shape = input_shape
+ output_shape[-1] = top_k
+ if out_max_val is True:
+ output_shape[-1] *= 2
+
+ return output_shape
+
+
+def argmax_layer(input, name, out_max_val=False, top_k=1, axis=-1):
+ """ build a layer of type 'ArgMax' using fluid
+
+ Args:
+ @input (variable): input fluid variable for this layer
+ @name (str): name for this layer
+ @out_max_val (bool): parameter from caffe's ArgMax layer
+ @top_k (int): parameter from caffe's ArgMax layer
+ @axis (int): parameter from caffe's ArgMax layer
+
+ Returns:
+ output (variable): output variable for this layer
+ """
+
+ fluid = import_fluid()
+
+ if axis < 0:
+ axis += len(input.shape)
+
+ topk_var, index_var = fluid.layers.topk(input=input, k=top_k)
+ if out_max_val is True:
+ index_var = fluid.layers.cast(index_var, dtype=topk_var.dtype)
+ output = fluid.layers.concat([index_var, topk_var], axis=axis)
+ else:
+ output = index_var
+
+ return output
+
+
+register(kind='ArgMax', shape=argmax_shape, layer=argmax_layer)
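
To make the shape arithmetic concrete, a tiny standalone check that mirrors `argmax_shape` for the common last-axis case (`expected_argmax_shape` is illustrative only, not part of the converter):

    def expected_argmax_shape(input_shape, out_max_val=False, top_k=1):
        out = list(input_shape)
        out[-1] = top_k * (2 if out_max_val else 1)  # values are stacked next to indices
        return out

    assert expected_argmax_shape([8, 1000], top_k=5) == [8, 5]
    assert expected_argmax_shape([8, 1000], out_max_val=True, top_k=5) == [8, 10]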
diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/axpy.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/axpy.py
new file mode 100644
index 0000000000000000000000000000000000000000..389bb7996e87b2813a7704ef5e0c14332f95ab08
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/axpy.py
@@ -0,0 +1,51 @@
+""" A custom layer for 'axpy' which receives 3 tensors and output 1 tensor.
+ the function performed is:(the mupltiplication and add are elementewise)
+ output = inputs[0] * inputs[1] + inputs[2]
+"""
+
+from .register import register
+
+
+def axpy_shape(input_shapes):
+ """ calculate the output shape of this layer using input shapes
+
+ Args:
+ @input_shapes (list of tuples): a list of input shapes
+
+ Returns:
+ @output_shape (list of num): a list of numbers that represent the output shape
+ """
+ assert len(input_shapes) == 3, "not valid input shape for axpy layer"
+ assert len(input_shapes[0]) == len(input_shapes[1]), 'should have same dims'
+
+ output_shape = input_shapes[1]
+ assert (input_shapes[2] == output_shape),\
+ "shape not consistent for axpy[%s <--> %s]" \
+ % (str(output_shape), str(input_shapes[2]))
+
+ return output_shape
+
+
+def axpy_layer(inputs, name):
+ """ build a layer of type 'Axpy' using fluid
+
+ Args:
+ @inputs (list of variables): input fluid variables for this layer
+ @name (str): name for this layer
+
+ Returns:
+ output (variable): output variable for this layer
+ """
+ import paddle.fluid as fluid
+
+ assert len(inputs) == 3, "invalid inputs for axpy[%s]" % (name)
+ alpha = inputs[0]
+ x = inputs[1]
+ y = inputs[2]
+ output = fluid.layers.elementwise_mul(x, alpha, axis=0)
+ output = fluid.layers.elementwise_add(output, y)
+
+ return output
+
+
+register(kind='Axpy', shape=axpy_shape, layer=axpy_layer)
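
The same computation in plain numpy for reference (the shapes follow the SENet-style use of axpy, an assumption here: alpha holds one scale per channel and broadcasts over H and W):

    import numpy as np

    alpha = np.random.rand(2, 3, 1, 1).astype('float32')  # inputs[0]: per-channel scale
    x = np.random.rand(2, 3, 8, 8).astype('float32')      # inputs[1]
    y = np.random.rand(2, 3, 8, 8).astype('float32')      # inputs[2]

    out = x * alpha + y  # elementwise multiply-add with broadcasting
    assert out.shape == y.shape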
diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/flatten.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/flatten.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f7af4266f7fd4b7b6e8ee868f44f1b35f35cb00
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/flatten.py
@@ -0,0 +1,73 @@
+""" a custom layer for 'flatten', maybe we should implement this in standard way.
+ more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/flatten.html
+"""
+from .register import register
+
+
+def import_fluid():
+ import paddle.fluid as fluid
+ return fluid
+
+
+def flatten_shape(input_shape, axis=1, end_axis=-1):
+ """ calculate the output shape of this layer using input shape
+
+ Args:
+ @input_shape (list of num): a list of numbers which represents the input shape
+ @axis (int): parameter from caffe's Flatten layer
+ @end_axis (int): parameter from caffe's Flatten layer
+
+ Returns:
+ @output_shape (list of num): a list of numbers that represent the output shape
+ """
+
+ start_axis = axis
+ input_shape = list(input_shape)
+ if start_axis < 0:
+ start_axis += len(input_shape)
+
+ if end_axis < 0:
+ end_axis += len(input_shape)
+
+ assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\
+ % (start_axis, end_axis)
+ output_shape = input_shape[0:start_axis]
+ # Caffe's end_axis is inclusive, so flatten the span [start_axis, end_axis]
+ flat_sz = reduce(lambda a, b: a * b, input_shape[start_axis:end_axis + 1])
+ output_shape += [flat_sz]
+ output_shape += input_shape[end_axis + 1:]
+
+ return output_shape
+
+
+def flatten_layer(input, name, axis=1, end_axis=-1):
+ """ build a layer of type 'Flatten' using fluid
+
+ Args:
+ @input (variable): input fluid variable for this layer
+ @name (str): name for this layer
+ @axis (int): parameter from caffe's Flatten layer
+ @end_axis (int): parameter from caffe's Flatten layer
+
+ Returns:
+ output (variable): output variable for this layer
+ """
+ fluid = import_fluid()
+
+ input_shape = list(input.shape)
+ dims = len(input_shape)
+ start_axis = axis if axis >= 0 else axis + dims
+ end_axis = end_axis if end_axis >= 0 else end_axis + dims
+
+ assert start_axis <= end_axis, 'invalid axis or end_axis params'
+ output_shape = input_shape[0:start_axis]
+ # Caffe's end_axis is inclusive, so flatten the span [start_axis, end_axis]
+ flat_sz = reduce(lambda a, b: a * b, input_shape[start_axis:end_axis + 1])
+ output_shape += [flat_sz]
+ output_shape += input_shape[end_axis + 1:]
+
+ output = fluid.layers.reshape(input, shape=output_shape, name=name)
+
+ return output
+
+
+register(kind='Flatten', shape=flatten_shape, layer=flatten_layer)
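
Again a standalone mirror of the shape rule, with Caffe's inclusive end_axis semantics that the span fix above follows (`expected_flatten_shape` is illustrative only):

    def expected_flatten_shape(shape, axis=1, end_axis=-1):
        nd = len(shape)
        start = axis + nd if axis < 0 else axis
        end = end_axis + nd if end_axis < 0 else end_axis
        flat = 1
        for d in shape[start:end + 1]:  # end_axis is inclusive, as in Caffe
            flat *= d
        return list(shape[:start]) + [flat] + list(shape[end + 1:])

    assert expected_flatten_shape([2, 3, 4, 5]) == [2, 60]
    assert expected_flatten_shape([2, 3, 4, 5], axis=2) == [2, 3, 20]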
diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/register.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/register.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae806cd469cb763dd06bbe406abb2ced3419cffc
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/register.py
@@ -0,0 +1,37 @@
+""" this module provides 'register' for registering customized layers
+"""
+
+g_custom_layers = {}
+
+
+def register(kind, shape, layer):
+ """ register a custom layer or a list of custom layers
+
+ Args:
+ @kind (str or list): type name of the layer
+ @shape (function): a function to generate the shape of layer's output
+ @layer (function): a function to build the fluid layer for this kind
+
+ Returns:
+ None
+ """
+ assert type(shape).__name__ == 'function', 'shape should be a function'
+ assert type(layer).__name__ == 'function', 'layer should be a function'
+
+ if type(kind) is str:
+ kind = [kind]
+ else:
+ assert type(
+ kind) is list, 'invalid param "kind" for register, not a list or str'
+
+ for k in kind:
+ assert type(
+ k) is str, 'invalid param "kind" for register, not a list of str'
+ assert k not in g_custom_layers, 'this type[%s] has already been registered' % (
+ k)
+ print('register layer[%s]' % (k))
+ g_custom_layers[k] = {'shape': shape, 'layer': layer}
+
+
+def get_registered_layers():
+ return g_custom_layers
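
Registering a new kind is then a one-liner; a minimal sketch with a hypothetical 'DoubleIt' layer (output shape unchanged, values scaled by 2; it assumes this package is importable as kaffe.custom_layers):

    from kaffe.custom_layers.register import register, get_registered_layers

    def double_shape(input_shape):
        return list(input_shape)  # shape is unchanged

    def double_layer(input, name):
        import paddle.fluid as fluid
        return fluid.layers.scale(input, scale=2.0)

    register(kind='DoubleIt', shape=double_shape, layer=double_layer)
    print(get_registered_layers().keys())  # now includes 'DoubleIt'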
diff --git a/fluid/image_classification/caffe2fluid/kaffe/graph.py b/fluid/image_classification/caffe2fluid/kaffe/graph.py
index c6fdada6e78c8fbeb98604033e4cb77995555ce9..6182a5352dac4746c64ebef0b3a886399dbd3d57 100644
--- a/fluid/image_classification/caffe2fluid/kaffe/graph.py
+++ b/fluid/image_classification/caffe2fluid/kaffe/graph.py
@@ -3,7 +3,7 @@ from google.protobuf import text_format
from .caffe import get_caffe_resolver
from .errors import KaffeError, print_stderr
from .layers import LayerAdapter, LayerType, NodeKind, NodeDispatch
-from .shapes import TensorShape
+from .shapes import make_tensor
class Node(object):
@@ -98,7 +98,7 @@ class Graph(object):
def compute_output_shapes(self):
sorted_nodes = self.topologically_sorted()
for node in sorted_nodes:
- node.output_shape = TensorShape(
+ node.output_shape = make_tensor(
*NodeKind.compute_output_shape(node))
def replaced(self, new_nodes):
@@ -111,6 +111,7 @@ class Graph(object):
if graph is None:
raise KaffeError('Transformer failed: {}'.format(transformer))
assert isinstance(graph, Graph)
+
return graph
def __contains__(self, key):
@@ -123,10 +124,18 @@ class Graph(object):
for node in self.topologically_sorted():
# If the node has learned parameters, display the first one's shape.
# In case of convolutions, this corresponds to the weights.
- data_shape = node.data[0].shape if node.data else '--'
- out_shape = node.output_shape or '--'
- s.append('{:<20} {:<30} {:>20} {:>20}'.format(
- node.kind, node.name, data_shape, tuple(out_shape)))
+ if node.data is None:
+ data_shape = '--'
+ out_shape = node.output_shape or '--'
+ s.append('{:<20} {:<30} {:>20} {:>20}'.format(
+ node.kind, node.name, data_shape, tuple(out_shape)))
+ else:
+ for d in node.data:
+ #data_shape = node.data[0].shape if node.data else '--'
+ data_shape = d.shape
+ out_shape = node.output_shape or '--'
+ s.append('{:<20} {:<30} {:>20} {:>20}'.format(
+ node.kind, node.name, data_shape, tuple(out_shape)))
return '\n'.join(s)
@@ -237,6 +246,7 @@ class GraphBuilder(object):
if (parent_node is None) or (parent_node == node):
parent_node = graph.get_node(input_name)
node.add_parent(parent_node)
+
if len(layer.top) > 1:
raise KaffeError('Multiple top nodes are not supported.')
diff --git a/fluid/image_classification/caffe2fluid/kaffe/layers.py b/fluid/image_classification/caffe2fluid/kaffe/layers.py
index f263407ab41458573f2df775f99202bed0e9d894..dcdd26040b6918d524f1d5ae58aa92f6da1a9550 100644
--- a/fluid/image_classification/caffe2fluid/kaffe/layers.py
+++ b/fluid/image_classification/caffe2fluid/kaffe/layers.py
@@ -2,6 +2,7 @@ import re
import numbers
from collections import namedtuple
+import custom_layers
from .shapes import *
LAYER_DESCRIPTORS = {
@@ -116,6 +117,9 @@ def get_v1_layer_map():
class NodeKind(LayerType):
@staticmethod
def map_raw_kind(kind):
+ if custom_layers.has_layer(kind):
+ return kind
+
if kind in LAYER_TYPES:
return kind
@@ -127,6 +131,9 @@ class NodeKind(LayerType):
@staticmethod
def compute_output_shape(node):
+ if custom_layers.has_layer(node.kind):
+ return custom_layers.compute_output_shape(node.kind, node)
+
try:
val = LAYER_DESCRIPTORS[node.kind](node)
return val
@@ -137,14 +144,13 @@ class NodeKind(LayerType):
class NodeDispatchError(KaffeError):
-
pass
class NodeDispatch(object):
@staticmethod
def get_handler_name(node_kind):
- if len(node_kind) <= 4:
+ if len(node_kind) <= 6:
# A catch-all for things like ReLU and tanh
return node_kind.lower()
# Convert from CamelCase to under_scored
@@ -152,6 +158,9 @@ class NodeDispatch(object):
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
def get_handler(self, node_kind, prefix):
+ if custom_layers.has_layer(node_kind):
+ return getattr(self, 'map_custom')
+
name = self.get_handler_name(node_kind)
name = '_'.join((prefix, name))
try:
@@ -174,8 +183,10 @@ class LayerAdapter(object):
try:
return getattr(self.layer, name)
except AttributeError:
+ print(dir(self.layer))
raise NodeDispatchError(
- 'Caffe parameters not found for layer kind: %s' % (self.kind))
+ 'Caffe parameters not found attr[%s] for layer kind[%s]' %
+ (name, self.kind))
@staticmethod
def get_kernel_value(scalar, repeated, idx, default=None):
diff --git a/fluid/image_classification/caffe2fluid/kaffe/net_template.py b/fluid/image_classification/caffe2fluid/kaffe/net_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..e57caf97948a903b02a136a38b0a0b716ac49057
--- /dev/null
+++ b/fluid/image_classification/caffe2fluid/kaffe/net_template.py
@@ -0,0 +1,151 @@
+""" this module is used as a template for generating sub class of Network
+"""
+
+
+class MyNet(object):
+ ### automatically generated by caffe2fluid ###
+ inputs_info = "INPUTS_INFO"
+ custom_layers_path = "CAFFE2FLUID_CUSTOM_LAYERS"
+
+ def custom_layer_factory(self):
+ import os
+ import sys  # sys.path is modified below when locating 'custom_layers'
+
+ pk_paths = []
+ default = os.path.dirname(os.path.abspath(__file__))
+ location = os.environ.get('CAFFE2FLUID_CUSTOM_LAYERS', default)
+ pk_name = 'custom_layers'
+ pk_dir = os.path.join(location, pk_name)
+ pk_paths.append((location, pk_dir))
+
+ location = MyNet.custom_layers_path
+ pk_dir = os.path.join(MyNet.custom_layers_path, pk_name)
+ pk_paths.append((location, pk_dir))
+
+ for loc, pk_dir in pk_paths:
+ if os.path.exists(pk_dir):
+ if loc not in sys.path:
+ sys.path.insert(0, loc)
+ break
+
+ try:
+ from custom_layers import make_custom_layer
+ return make_custom_layer
+ except Exception as e:
+ print('maybe you should set $CAFFE2FLUID_CUSTOM_LAYERS first')
+ raise e
+
+ @classmethod
+ def input_shapes(cls):
+ return cls.inputs_info
+
+ @classmethod
+ def convert(cls, npy_model, fluid_path, outputs=None):
+ fluid = import_fluid()
+ shapes = cls.input_shapes()
+ input_name = shapes.keys()[0]
+ feed_data = {}
+ for name, shape in shapes.items():
+ data_layer = fluid.layers.data(
+ name=name, shape=shape, dtype="float32")
+ feed_data[name] = data_layer
+
+ net = cls(feed_data)
+ place = fluid.CPUPlace()
+ exe = fluid.Executor(place)
+ exe.run(fluid.default_startup_program())
+ net.load(data_path=npy_model, exe=exe, place=place)
+ output_vars = []
+ if outputs is None:
+ output_vars.append(net.get_output())
+ else:
+ if type(outputs) is list:
+ for n in outputs:
+ assert n in net.layers, 'no layer found with this name[%s]' % (
+ n)
+ output_vars.append(net.layers[n])
+
+ fluid.io.save_inference_model(
+ fluid_path, [input_name],
+ output_vars,
+ exe,
+ main_program=None,
+ model_filename='model',
+ params_filename='params')
+ return 0
+
+
+def main():
+ """ a tool used to convert caffe model to fluid
+ """
+
+ import sys
+ import os
+ filename = os.path.splitext(os.path.basename(sys.argv[0]))[0]
+ if len(sys.argv) < 3:
+ print('usage:')
+ print(' python %s %s.npy [save_dir] [layer names separated by comma]' \
+ % (sys.argv[0], filename))
+ print(' eg: python %s %s.npy ./fluid' % (sys.argv[0], filename))
+ print(' eg: python %s %s.npy ./fluid layer_name1,layer_name2' \
+ % (sys.argv[0], filename))
+ return 1
+
+ npy_weight = sys.argv[1]
+ fluid_model = sys.argv[2]
+ outputs = None
+ if len(sys.argv) >= 4:
+ outputs = sys.argv[3].split(',')
+
+ ret = MyNet.convert(npy_weight, fluid_model, outputs)
+ if ret == 0:
+ outputs = 'last output layer' if outputs is None else outputs
+ print('succeeded in converting to fluid format with output layers[%s]'
+ ' in directory[%s]' % (outputs, fluid_model))
+ else:
+ print('failed to convert model to fluid format')
+
+ return ret
+
+
+def generate_net_code(net_name, inputs_info):
+ """ generate framework of a custom net code which represent a subclass of Network
+
+ Args:
+ @net_name (str): class name for this net
+ @inputs_info (str): a str which represents a dict, eg: '{"data": [3, 32, 32]}'
+ Returns:
+ net_codes (str): codes for this subclass
+ """
+ import os
+ import inspect
+
+ net_codes = str(inspect.getsource(MyNet))
+ net_codes = net_codes.replace('MyNet(object)', '%s(Network)' % net_name)
+ net_codes = net_codes.replace('"INPUTS_INFO"', inputs_info)
+
+ custom_layer_dir = os.path.dirname(os.path.abspath(__file__))
+ net_codes = net_codes.replace('CAFFE2FLUID_CUSTOM_LAYERS', custom_layer_dir)
+ return net_codes
+
+
+def generate_main_code(net_name):
+ """ generate a piece of code for 'main' function
+
+ Args:
+ @net_name (str): class name for this net
+
+ Returns:
+ main_codes (str): codes for this main function
+ """
+ import inspect
+
+ main_codes = str(inspect.getsource(main))
+ main_codes = main_codes.replace('MyNet', net_name)
+ return main_codes
+
+
+if __name__ == "__main__":
+ """ just for testing
+ """
+ print generate_net_code('Attribute', "{'data': [3, 277, 277]}")
+ print generate_main_code('Attribute')
diff --git a/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py b/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py
index ac5ecf1d4491efb5043502824514498f79ab4db0..258830bdac00af8fb9f2e83207730b404a04f7d5 100644
--- a/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py
+++ b/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py
@@ -1,5 +1,6 @@
-import math
+import sys
import os
+import math
import numpy as np
@@ -161,7 +162,8 @@ class Network(object):
output = fluid.layers.relu(x=input)
return output
- def pool(self, pool_type, input, k_h, k_w, s_h, s_w, name, padding):
+ def pool(self, pool_type, input, k_h, k_w, s_h, s_w, ceil_mode, padding,
+ name):
# Get the number of channels in the input
in_hw = input.shape[2:]
k_hw = [k_h, k_w]
@@ -173,17 +175,40 @@ class Network(object):
pool_size=k_hw,
pool_stride=s_hw,
pool_padding=padding,
- ceil_mode=True,
+ ceil_mode=ceil_mode,
pool_type=pool_type)
return output
@layer
- def max_pool(self, input, k_h, k_w, s_h, s_w, name, padding=[0, 0]):
- return self.pool('max', input, k_h, k_w, s_h, s_w, name, padding)
+ def max_pool(self,
+ input,
+ k_h,
+ k_w,
+ s_h,
+ s_w,
+ ceil_mode,
+ padding=[0, 0],
+ name=None):
+ return self.pool('max', input, k_h, k_w, s_h, s_w, ceil_mode, padding,
+ name)
+
+ @layer
+ def avg_pool(self,
+ input,
+ k_h,
+ k_w,
+ s_h,
+ s_w,
+ ceil_mode,
+ padding=[0, 0],
+ name=None):
+ return self.pool('avg', input, k_h, k_w, s_h, s_w, ceil_mode, padding,
+ name)
@layer
- def avg_pool(self, input, k_h, k_w, s_h, s_w, name, padding=[0, 0]):
- return self.pool('avg', input, k_h, k_w, s_h, s_w, name, padding)
+ def sigmoid(self, input, name):
+ fluid = import_fluid()
+ return fluid.layers.sigmoid(input)
@layer
def lrn(self, input, radius, alpha, beta, name, bias=1.0):
@@ -264,3 +289,16 @@ class Network(object):
output = fluid.layers.dropout(
input, dropout_prob=drop_prob, is_test=is_test, name=name)
return output
+
+ def custom_layer_factory(self):
+ """ get a custom layer maker provided by subclass
+ """
+ raise NotImplementedError(
+ '[custom_layer_factory] must be implemented by the subclass.')
+
+ @layer
+ def custom_layer(self, inputs, kind, name, *args, **kwargs):
+ """ make custom layer
+ """
+ layer_factory = self.custom_layer_factory()
+ return layer_factory(kind, inputs, name, *args, **kwargs)
diff --git a/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py b/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py
index 3697529971fa6ca01d1703375243d16f0a0c1edd..6aa3b38531f946e4656e05c52c69087f3b89aaf4 100644
--- a/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py
+++ b/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py
@@ -109,9 +109,17 @@ class TensorFlowMapper(NodeMapper):
# Stochastic pooling, for instance.
raise KaffeError('Unsupported pooling type.')
(kernel_params, padding) = self.get_kernel_params(node)
+ ceil_mode = getattr(node.layer.parameters, 'ceil_mode', True)
return TensorFlowNode(pool_op, kernel_params.kernel_h,
kernel_params.kernel_w, kernel_params.stride_h,
- kernel_params.stride_w, **padding)
+ kernel_params.stride_w, ceil_mode, **padding)
+
+ def map_sigmoid(self, node):
+ return TensorFlowNode('sigmoid')
+
+ def map_custom(self, node):
+ from .. import custom_layers
+ return custom_layers.make_node(TensorFlowNode, node.kind, node)
def map_inner_product(self, node):
#TODO: Axis
@@ -190,18 +198,10 @@ class TensorFlowEmitter(object):
codes.append(network_source + '\n')
return self.statement('\n'.join(codes))
- def emit_class_def(self, name):
- return self.statement('class %s(Network):' % (name))
-
def emit_setup_def(self):
return self.statement('def setup(self):')
- def emit_shape_def(self, input_nodes):
- self.outdent()
- func_def = self.statement('@classmethod')
- func_def += self.statement('def input_shapes(cls):')
- self.indent()
-
+ def get_inputs_info(self, input_nodes):
input_shapes = {}
for n in input_nodes:
name = n.name
@@ -210,42 +210,7 @@ class TensorFlowEmitter(object):
input_shapes[name] = ', '.join(shape)
input_shapes = ['"%s": [%s]' % (n, l) for n, l in input_shapes.items()]
shape_str = ','.join(input_shapes)
- func_def += self.statement('return {%s}' % (shape_str))
- return '\n\n' + func_def
-
- def emit_convert_def(self, input_nodes):
- codes = []
- inputs = {}
- codes.append('shapes = cls.input_shapes()')
- for n in input_nodes:
- name = n.name
- layer_var = name + '_layer'
- layer_def = '%s = fluid.layers.data(name="%s", shape=shapes["%s"],'\
- ' dtype="float32")' % (layer_var, name, name)
- #layer_var, layer_def = data_layer_def(n.name, n.output_shape)
- codes.append(layer_def)
- inputs[name] = layer_var
-
- input_dict = ','.join(['"%s": %s' % (n, l) for n, l in inputs.items()])
-
- codes.append('feed_data = {' + input_dict + '}')
- codes.append('net = cls(feed_data)')
-
- codes.append("place = fluid.CPUPlace()")
- codes.append("exe = fluid.Executor(place)")
- codes.append("exe.run(fluid.default_startup_program())")
- codes.append("net.load(data_path=npy_model, exe=exe, place=place)")
- codes.append(
- "fluid.io.save_persistables(executor=exe, dirname=fluid_path)")
-
- self.outdent()
- func_def = self.statement('@classmethod')
- func_def += self.statement('def convert(cls, npy_model, fluid_path):')
- self.indent()
- func_def += self.statement('fluid = import_fluid()')
- for l in codes:
- func_def += self.statement(l)
- return '\n' + func_def
+ return '{%s}' % (shape_str)
def emit_main_def(self, name):
if name is None:
@@ -254,13 +219,7 @@ class TensorFlowEmitter(object):
self.prefix = ''
main_def = self.statement('if __name__ == "__main__":')
self.indent()
- main_def += self.statement("#usage: python xxxnet.py xxx.npy ./model\n")
- main_def += self.statement("import sys")
- main_def += self.statement("npy_weight = sys.argv[1]")
- main_def += self.statement("fluid_model = sys.argv[2]")
- main_def += self.statement("%s.convert(npy_weight, fluid_model)" %
- (name))
- main_def += self.statement("exit(0)")
+ main_def += self.statement('exit(main())')
return '\n\n' + main_def
def emit_parents(self, chain):
@@ -275,10 +234,17 @@ class TensorFlowEmitter(object):
return self.statement('self.' + node.emit())
def emit(self, name, chains, input_nodes=None):
+ from ..net_template import generate_net_code
+ from ..net_template import generate_main_code
+
self.net_name = name
+ inputs_info = self.get_inputs_info(input_nodes)
+
s = self.emit_imports()
- s += self.emit_class_def(name)
+ s += generate_net_code(name, inputs_info) + '\n'
self.indent()
+
+ # define the net using api
s += self.emit_setup_def()
self.indent()
blocks = []
@@ -289,8 +255,9 @@ class TensorFlowEmitter(object):
b += self.emit_node(node)
blocks.append(b[:-1])
s = s + '\n\n'.join(blocks)
- s += self.emit_shape_def(input_nodes)
- s += self.emit_convert_def(input_nodes)
+
+ # define the main function
+ s += '\n\n\n' + generate_main_code(name)
s += self.emit_main_def(name)
return s
@@ -329,6 +296,7 @@ class Transformer(object):
# (Caffe's GoogLeNet implementation uses slashes)
NodeRenamer(lambda node: node.name.replace('/', '_'))
]
+
self.graph = graph.transformed(transformers)
# Display the graph
@@ -340,9 +308,6 @@ class Transformer(object):
transformers = [
# Reshape the parameters to TensorFlow's ordering
DataReshaper({
- # (c_o, c_i, h, w) -> (h, w, c_i, c_o) for TF
- NodeKind.Convolution: (0, 1, 2, 3),
-
# (c_o, c_i) -> (c_i, c_o)
NodeKind.InnerProduct: (1, 0)
}),
diff --git a/fluid/image_classification/caffe2fluid/kaffe/shapes.py b/fluid/image_classification/caffe2fluid/kaffe/shapes.py
index e8124730c66eaecb85f7aff58e08f6dc16668343..a2ce26362bb9afd659f8db7d678afeabd3efa6b5 100644
--- a/fluid/image_classification/caffe2fluid/kaffe/shapes.py
+++ b/fluid/image_classification/caffe2fluid/kaffe/shapes.py
@@ -3,8 +3,24 @@ from collections import namedtuple
from .errors import KaffeError
-TensorShape = namedtuple('TensorShape',
- ['batch_size', 'channels', 'height', 'width'])
+Tensor4DShape = namedtuple('Tensor4DShape',
+ ['batch_size', 'channels', 'height', 'width'])
+
+Tensor2DShape = namedtuple('Tensor2DShape', ['batch_size', 'data'])
+
+ScalarShape = namedtuple('ScalarShape', ['batch_size'])
+
+
+def make_tensor(batch_size, d1=None, d2=None, d3=None):
+ if d3 is not None:
+ return Tensor4DShape(batch_size, d1, d2, d3)
+ elif d1 is not None and d2 is None:
+ return Tensor2DShape(batch_size, d1)
+ elif d1 is None and d2 is None and d3 is None:
+ return ScalarShape(batch_size)
+ else:
+ raise NotImplementedError('invalid params for make_tensor %s' \
+ % (str((batch_size, d1, d2, d3))))
def get_filter_output_shape(i_h, i_w, params, round_func):
@@ -23,7 +39,7 @@ def get_strided_kernel_output_shape(node, round_func):
params = node.layer.parameters
has_c_o = hasattr(params, 'num_output')
c = params.num_output if has_c_o else input_shape.channels
- return TensorShape(input_shape.batch_size, c, o_h, o_w)
+ return make_tensor(input_shape.batch_size, c, o_h, o_w)
def shape_not_implemented(node):
@@ -36,7 +52,7 @@ def shape_identity(node):
def shape_scalar(node):
- return TensorShape(1, 1, 1, 1)
+ return make_tensor(1, 1, 1, 1)
def shape_data(node):
@@ -59,7 +75,7 @@ def shape_data(node):
def shape_mem_data(node):
params = node.parameters
- return TensorShape(params.batch_size, params.channels, params.height,
+ return make_tensor(params.batch_size, params.channels, params.height,
params.width)
@@ -79,10 +95,15 @@ def shape_convolution(node):
def shape_pool(node):
- return get_strided_kernel_output_shape(node, math.ceil)
+ ceil_mode = getattr(node.layer.parameters, 'ceil_mode', True)
+ if ceil_mode is True:
+ method = math.ceil
+ else:
+ method = math.floor
+
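+    # e.g. with the usual (i + 2*pad - k)/stride + 1 output-size formula, an
+    # input of 6 with kernel 3, stride 2, pad 0 gives ceil -> 3 but floor -> 2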
+ return get_strided_kernel_output_shape(node, method)
def shape_inner_product(node):
input_shape = node.get_only_parent().output_shape
- return TensorShape(input_shape.batch_size, node.layer.parameters.num_output,
- 1, 1)
+ return make_tensor(input_shape.batch_size, node.layer.parameters.num_output)
diff --git a/fluid/image_classification/caffe2fluid/kaffe/transformers.py b/fluid/image_classification/caffe2fluid/kaffe/transformers.py
index 9d300ca9c90672c3f3a3dbf7a14e48db6bb48f70..6d98703da3313cf466eb43c2adc49c0e0640a8de 100644
--- a/fluid/image_classification/caffe2fluid/kaffe/transformers.py
+++ b/fluid/image_classification/caffe2fluid/kaffe/transformers.py
@@ -66,12 +66,14 @@ class DataInjector(object):
def adjust_parameters(self, node, data):
if not self.did_use_pb:
return data
+
# When using the protobuf-backend, each parameter initially has four dimensions.
# In certain cases (like FC layers), we want to eliminate the singleton dimensions.
# This implementation takes care of the common cases. However, it does leave the
# potential for future issues.
# The Caffe-backend does not suffer from this problem.
data = list(data)
+
squeeze_indices = [1] # Squeeze biases.
if node.kind == NodeKind.InnerProduct:
squeeze_indices.append(0) # Squeeze FC.
@@ -80,8 +82,22 @@ class DataInjector(object):
if idx >= len(data):
continue
- shape_old = data[idx].shape
- data[idx] = np.squeeze(data[idx])
+ d = data[idx]
+ assert len(
+ d.shape
+ ) == 4, 'invalid shape[%s] from caffe when adjust_parameters' % (
+ str(d.shape))
+
+ shape_old = d.shape
+ sq_axis = None
+ if idx == 0:
+ sq_axis = (0, 1)
+ elif idx == 1:
+ sq_axis = (0, 1, 2)
+ else:
+ continue
+
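+            # e.g. a protobuf-backend FC weight blob of shape (1, 1, c_o, c_i)
+            # squeezes to (c_o, c_i); a bias blob (1, 1, 1, c_o) to (c_o,)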
+ data[idx] = np.squeeze(d, axis=sq_axis)
shape_new = data[idx].shape
if len(shape_old) != shape_new:
debug('squeeze idx:%d, with kind:%s,name:%s' % \
@@ -113,7 +129,10 @@ class DataReshaper(object):
try:
parent = node.get_only_parent()
s = parent.output_shape
- return s.height > 1 or s.width > 1
+ if len(s) == 4:
+ return s.height > 1 or s.width > 1
+ else:
+ return False
except KaffeError:
return False
@@ -121,25 +140,26 @@ class DataReshaper(object):
try:
return self.mapping[node_kind]
except KeyError:
- raise
- #raise KaffeError('Ordering not found for node kind: {}'.format(node_kind))
+ raise KaffeError('Ordering not found for node kind: {}'.format(
+ node_kind))
def __call__(self, graph):
for node in graph.nodes:
if node.data is None:
continue
+
if node.kind not in self.reshaped_node_types:
# Check for 2+ dimensional data
if any(len(tensor.shape) > 1 for tensor in node.data):
notice('parmaters not reshaped for node: {}'.format(node))
continue
+
transpose_order = self.map(node.kind)
weights = node.data[0]
- if (node.kind == NodeKind.InnerProduct
- ) and self.has_spatial_parent(node):
+ if node.kind == NodeKind.InnerProduct:
# The FC layer connected to the spatial layer needs to be
# re-wired to match the new spatial ordering.
- in_shape = node.get_only_parent().output_shape
+ #in_shape = node.get_only_parent().output_shape
fc_shape = weights.shape
output_channels = fc_shape[0]
weights = weights.reshape((output_channels, -1))
@@ -178,7 +198,8 @@ class SubNodeFuser(object):
continue
# Rewrite the fused node's children to its parent.
for child in node.children:
- child.parents.remove(node)
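+                # swap the fused node for its parent in place so the child's
+                # parent ordering is preserved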
+ pos = child.parents.index(node)
+ child.parents[pos] = parent
parent.add_child(child)
# Disconnect the fused node from the graph.
parent.children.remove(node)
diff --git a/fluid/image_classification/caffe2fluid/proto/compile.sh b/fluid/image_classification/caffe2fluid/proto/compile.sh
old mode 100644
new mode 100755
diff --git a/fluid/image_classification/train.py b/fluid/image_classification/train.py
index f402c87d49862fd844d8cf36c6eb52f3e21895b3..6244e520900b58914b847c4acb451beb252efd30 100644
--- a/fluid/image_classification/train.py
+++ b/fluid/image_classification/train.py
@@ -18,17 +18,19 @@ add_arg('batch_size', int, 256, "Minibatch size.")
add_arg('num_layers', int, 50, "How many layers for SE-ResNeXt model.")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('parallel_exe', bool, True, "Whether to use ParallelExecutor to train or not.")
-
-def train_paralle_do(args,
- learning_rate,
- batch_size,
- num_passes,
- init_model=None,
- model_save_dir='model',
- parallel=True,
- use_nccl=True,
- lr_strategy=None,
- layers=50):
+# yapf: enable
+
+
+def train_parallel_do(args,
+ learning_rate,
+ batch_size,
+ num_passes,
+ init_model=None,
+ model_save_dir='model',
+ parallel=True,
+ use_nccl=True,
+ lr_strategy=None,
+ layers=50):
class_dim = 1000
image_shape = [3, 224, 224]
@@ -62,6 +64,8 @@ def train_paralle_do(args,
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+ inference_program = fluid.default_main_program().clone(for_test=True)
+
if lr_strategy is None:
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
@@ -76,12 +80,9 @@ def train_paralle_do(args,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
- inference_program = fluid.default_main_program().clone(for_test=True)
-
opts = optimizer.minimize(avg_cost)
if args.with_mem_opt:
fluid.memory_optimize(fluid.default_main_program())
- fluid.memory_optimize(inference_program)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
@@ -154,6 +155,7 @@ def train_paralle_do(args,
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
+
def train_parallel_exe(args,
learning_rate,
batch_size,
@@ -195,7 +197,6 @@ def train_parallel_exe(args,
if args.with_mem_opt:
fluid.memory_optimize(fluid.default_main_program())
- fluid.memory_optimize(test_program)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
@@ -210,9 +211,7 @@ def train_parallel_exe(args,
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
test_exe = fluid.ParallelExecutor(
- use_cuda=True,
- main_program=test_program,
- share_vars_from=train_exe)
+ use_cuda=True, main_program=test_program, share_vars_from=train_exe)
fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]
@@ -221,9 +220,8 @@ def train_parallel_exe(args,
test_info = [[], [], []]
for batch_id, data in enumerate(train_reader()):
t1 = time.time()
- loss, acc1, acc5 = train_exe.run(
- fetch_list,
- feed_dict=feeder.feed(data))
+ loss, acc1, acc5 = train_exe.run(fetch_list,
+ feed_dict=feeder.feed(data))
t2 = time.time()
period = t2 - t1
loss = np.mean(np.array(loss))
@@ -245,9 +243,8 @@ def train_parallel_exe(args,
train_acc5 = np.array(train_info[2]).mean()
for data in test_reader():
t1 = time.time()
- loss, acc1, acc5 = test_exe.run(
- fetch_list,
- feed_dict=feeder.feed(data))
+ loss, acc1, acc5 = test_exe.run(fetch_list,
+ feed_dict=feeder.feed(data))
t2 = time.time()
period = t2 - t1
loss = np.mean(np.array(loss))
@@ -281,8 +278,6 @@ def train_parallel_exe(args,
fluid.io.save_persistables(exe, model_path)
-
-
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
@@ -300,12 +295,13 @@ if __name__ == '__main__':
# layers: 50, 152
layers = args.num_layers
method = train_parallel_exe if args.parallel_exe else train_parallel_do
- method(args,
- learning_rate=0.1,
- batch_size=batch_size,
- num_passes=120,
- init_model=None,
- parallel=True,
- use_nccl=True,
- lr_strategy=lr_strategy,
- layers=layers)
+ method(
+ args,
+ learning_rate=0.1,
+ batch_size=batch_size,
+ num_passes=120,
+ init_model=None,
+ parallel=True,
+ use_nccl=True,
+ lr_strategy=lr_strategy,
+ layers=layers)
diff --git a/fluid/language_model/README.md b/fluid/language_model/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..91ce2d7f58085b56da2ac2dec03af2a05985ab8f
--- /dev/null
+++ b/fluid/language_model/README.md
@@ -0,0 +1,148 @@
+# Language Model
+
+The following is a brief overview of this example's directory structure:
+
+```text
+.
+├── README.md          # documentation
+├── train.py           # training script
+├── infer.py           # inference script
+└── utils.py           # common utility functions
+```
+
+
+## Introduction
+
+For an introduction to recurrent neural network language models, please refer to the paper [Recurrent Neural Network Regularization](https://arxiv.org/abs/1409.2329). In this example we implement a GRU-RNN language model.
+
+## Training
+
+Run the command `python train.py` to start training the model:
+```bash
+python train.py
+```
+
+The currently supported parameters are listed in the `train_net` function of [train.py](./train.py):
+```python
+vocab, train_reader, test_reader = utils.prepare_data(
+ batch_size=20, # batch size
+ buffer_size=1000, # buffer size, default value is OK
+    word_freq_threshold=0) # vocabulary-related parameter; words with frequency below this value are filtered out
+
+train(train_reader=train_reader,
+ vocab=vocab,
+ network=network,
+ hid_size=200, # embedding and hidden size
+ base_lr=1.0, # base learning rate
+ batch_size=20, # batch size, the same as that in prepare_data
+ pass_num=12, # the number of passes for training
+ use_cuda=True, # whether to use GPU card
+ parallel=False, # whether to be parallel
+ model_dir="model", # directory to save model
+ init_low_bound=-0.1, # uniform parameter initialization lower bound
+ init_high_bound=0.1) # uniform parameter initialization upper bound
+```
+
+## Customizing the Network
+
+The network structure can be adjusted in the `network` function of [train.py](./train.py); the current structure is as follows:
+```python
+emb = fluid.layers.embedding(input=src, size=[vocab_size, hid_size],
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(low=init_low_bound, high=init_high_bound),
+ learning_rate=emb_lr_x),
+ is_sparse=True)
+
+fc0 = fluid.layers.fc(input=emb, size=hid_size * 3,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(low=init_low_bound, high=init_high_bound),
+ learning_rate=gru_lr_x))
+gru_h0 = fluid.layers.dynamic_gru(input=fc0, size=hid_size,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(low=init_low_bound, high=init_high_bound),
+ learning_rate=gru_lr_x))
+
+fc = fluid.layers.fc(input=gru_h0, size=vocab_size, act='softmax',
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(low=init_low_bound, high=init_high_bound),
+ learning_rate=fc_lr_x))
+
+cost = fluid.layers.cross_entropy(input=fc, label=dst)
+```
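+
+As a minimal sketch (untested here), the GRU can be swapped for an LSTM; `fluid.layers.dynamic_lstm` takes `size` equal to four times the hidden width, so the projection layer changes accordingly:
+```python
+fc0 = fluid.layers.fc(input=emb, size=hid_size * 4)
+lstm_h, c = fluid.layers.dynamic_lstm(input=fc0, size=hid_size * 4, is_reverse=False)
+fc = fluid.layers.fc(input=lstm_h, size=vocab_size, act='softmax')
+cost = fluid.layers.cross_entropy(input=fc, label=dst)
+```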
+
+## Sample Training Results
+
+The training log on a single Tesla K40m GPU is shown below:
+```text
+epoch_1 start
+step:100 ppl:771.053
+step:200 ppl:449.597
+step:300 ppl:642.654
+step:400 ppl:458.128
+step:500 ppl:510.912
+step:600 ppl:451.545
+step:700 ppl:364.404
+step:800 ppl:324.272
+step:900 ppl:360.797
+step:1000 ppl:275.761
+step:1100 ppl:294.599
+step:1200 ppl:335.877
+step:1300 ppl:185.262
+step:1400 ppl:241.744
+step:1500 ppl:211.507
+step:1600 ppl:233.431
+step:1700 ppl:298.767
+step:1800 ppl:203.403
+step:1900 ppl:158.828
+step:2000 ppl:171.148
+step:2100 ppl:280.884
+epoch:1 num_steps:2104 time_cost(s):47.478780
+model saved in model/epoch_1
+epoch_2 start
+step:100 ppl:238.099
+step:200 ppl:136.527
+step:300 ppl:204.184
+step:400 ppl:252.886
+step:500 ppl:177.377
+step:600 ppl:197.688
+step:700 ppl:131.650
+step:800 ppl:223.906
+step:900 ppl:144.785
+step:1000 ppl:176.286
+step:1100 ppl:148.158
+step:1200 ppl:203.581
+step:1300 ppl:168.208
+step:1400 ppl:159.412
+step:1500 ppl:114.032
+step:1600 ppl:157.985
+step:1700 ppl:147.743
+step:1800 ppl:88.676
+step:1900 ppl:141.962
+step:2000 ppl:106.087
+step:2100 ppl:122.709
+epoch:2 num_steps:2104 time_cost(s):47.583789
+model saved in model/epoch_2
+...
+```
+
+## Inference
+Run `python infer.py model_dir start_epoch last_epoch(inclusive)` to start inference, where start_epoch is the first epoch to evaluate and last_epoch is the last one (inclusive), e.g.
+```bash
+python infer.py model 1 12 # prediction from epoch 1 to epoch 12
+```
+
+## Sample Inference Results
+```text
+model:model/epoch_1 ppl:254.540 time_cost(s):3.29
+model:model/epoch_2 ppl:177.671 time_cost(s):3.27
+model:model/epoch_3 ppl:156.251 time_cost(s):3.27
+model:model/epoch_4 ppl:139.036 time_cost(s):3.27
+model:model/epoch_5 ppl:132.661 time_cost(s):3.27
+model:model/epoch_6 ppl:130.092 time_cost(s):3.28
+model:model/epoch_7 ppl:128.751 time_cost(s):3.27
+model:model/epoch_8 ppl:125.411 time_cost(s):3.27
+model:model/epoch_9 ppl:124.604 time_cost(s):3.28
+model:model/epoch_10 ppl:124.754 time_cost(s):3.29
+model:model/epoch_11 ppl:125.421 time_cost(s):3.27
+model:model/epoch_12 ppl:125.676 time_cost(s):3.27
+```
diff --git a/fluid/language_model/infer.py b/fluid/language_model/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a183d54852dc9d56b76968a5f450479a43325304
--- /dev/null
+++ b/fluid/language_model/infer.py
@@ -0,0 +1,65 @@
+import sys
+import time
+import math
+import unittest
+import contextlib
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.v2 as paddle
+
+import utils
+
+
+def infer(test_reader, use_cuda, model_path):
+ """ inference function """
+ place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+ exe = fluid.Executor(place)
+
+ with fluid.scope_guard(fluid.core.Scope()):
+ infer_program, feed_target_names, fetch_vars = fluid.io.load_inference_model(
+ model_path, exe)
+
+ accum_cost = 0.0
+ accum_words = 0
+ t0 = time.time()
+ for data in test_reader():
+ src_wordseq = utils.to_lodtensor(map(lambda x: x[0], data), place)
+ dst_wordseq = utils.to_lodtensor(map(lambda x: x[1], data), place)
+ avg_cost = exe.run(
+ infer_program,
+ feed={"src_wordseq": src_wordseq,
+ "dst_wordseq": dst_wordseq},
+ fetch_list=fetch_vars)
+
+ nwords = src_wordseq.lod()[0][-1]
+
+ cost = np.array(avg_cost) * nwords
+ accum_cost += cost
+ accum_words += nwords
+
+ ppl = math.exp(accum_cost / accum_words)
+ t1 = time.time()
+ print("model:%s ppl:%.3f time_cost(s):%.2f" %
+ (model_path, ppl, t1 - t0))
+
+
+if __name__ == "__main__":
+ if len(sys.argv) != 4:
+ print("Usage: %s model_dir start_epoch last_epoch(inclusive)")
+ exit(0)
+
+ model_dir = sys.argv[1]
+ try:
+ start_index = int(sys.argv[2])
+ last_index = int(sys.argv[3])
+    except ValueError:
+        print("Usage: %s model_dir start_epoch last_epoch(inclusive)" % sys.argv[0])
+ exit(-1)
+
+ vocab, train_reader, test_reader = utils.prepare_data(
+ batch_size=20, buffer_size=1000, word_freq_threshold=0)
+
+ for epoch in xrange(start_index, last_index + 1):
+ epoch_path = model_dir + "/epoch_" + str(epoch)
+ infer(test_reader=test_reader, use_cuda=True, model_path=epoch_path)
diff --git a/fluid/language_model/train.py b/fluid/language_model/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..59fc3a987746af7aec9b61b5c817400b6b6546d0
--- /dev/null
+++ b/fluid/language_model/train.py
@@ -0,0 +1,162 @@
+import sys
+import time
+
+import numpy as np
+import math
+
+import paddle.fluid as fluid
+import paddle.v2 as paddle
+
+import utils
+
+
+def network(src, dst, vocab_size, hid_size, init_low_bound, init_high_bound):
+ """ network definition """
+ emb_lr_x = 10.0
+ gru_lr_x = 1.0
+ fc_lr_x = 1.0
+ emb = fluid.layers.embedding(
+ input=src,
+ size=[vocab_size, hid_size],
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(
+ low=init_low_bound, high=init_high_bound),
+ learning_rate=emb_lr_x),
+ is_sparse=True)
+
+ fc0 = fluid.layers.fc(input=emb,
+ size=hid_size * 3,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(
+ low=init_low_bound, high=init_high_bound),
+ learning_rate=gru_lr_x))
+ gru_h0 = fluid.layers.dynamic_gru(
+ input=fc0,
+ size=hid_size,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(
+ low=init_low_bound, high=init_high_bound),
+ learning_rate=gru_lr_x))
+
+ fc = fluid.layers.fc(input=gru_h0,
+ size=vocab_size,
+ act='softmax',
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Uniform(
+ low=init_low_bound, high=init_high_bound),
+ learning_rate=fc_lr_x))
+
+ cost = fluid.layers.cross_entropy(input=fc, label=dst)
+ return cost
+
+
+def train(train_reader,
+ vocab,
+ network,
+ hid_size,
+ base_lr,
+ batch_size,
+ pass_num,
+ use_cuda,
+ parallel,
+ model_dir,
+ init_low_bound=-0.04,
+ init_high_bound=0.04):
+ """ train network """
+ vocab_size = len(vocab)
+
+ src_wordseq = fluid.layers.data(
+ name="src_wordseq", shape=[1], dtype="int64", lod_level=1)
+ dst_wordseq = fluid.layers.data(
+ name="dst_wordseq", shape=[1], dtype="int64", lod_level=1)
+
+ avg_cost = None
+ if not parallel:
+ cost = network(src_wordseq, dst_wordseq, vocab_size, hid_size,
+ init_low_bound, init_high_bound)
+ avg_cost = fluid.layers.mean(x=cost)
+ else:
+ places = fluid.layers.get_places()
+ pd = fluid.layers.ParallelDo(places)
+ with pd.do():
+ cost = network(
+ pd.read_input(src_wordseq),
+ pd.read_input(dst_wordseq), vocab_size, hid_size,
+ init_low_bound, init_high_bound)
+ pd.write_output(cost)
+
+ cost = pd()
+ avg_cost = fluid.layers.mean(x=cost)
+
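+    # with staircase decay, this halves the learning rate every 4 epochs
+    # (the PTB train set yields ~2100 batches per epoch at batch_size=20,
+    # cf. num_steps in the sample log)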
+ sgd_optimizer = fluid.optimizer.SGD(
+ learning_rate=fluid.layers.exponential_decay(
+ learning_rate=base_lr,
+ decay_steps=2100 * 4,
+ decay_rate=0.5,
+ staircase=True))
+ sgd_optimizer.minimize(avg_cost)
+
+ place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+ exe = fluid.Executor(place)
+
+ exe.run(fluid.default_startup_program())
+ total_time = 0.0
+ for pass_idx in xrange(pass_num):
+ epoch_idx = pass_idx + 1
+ print "epoch_%d start" % epoch_idx
+
+ t0 = time.time()
+ i = 0
+ for data in train_reader():
+ i += 1
+ lod_src_wordseq = utils.to_lodtensor(
+ map(lambda x: x[0], data), place)
+ lod_dst_wordseq = utils.to_lodtensor(
+ map(lambda x: x[1], data), place)
+ ret_avg_cost = exe.run(fluid.default_main_program(),
+ feed={
+ "src_wordseq": lod_src_wordseq,
+ "dst_wordseq": lod_dst_wordseq
+ },
+ fetch_list=[avg_cost],
+ use_program_cache=True)
+ avg_ppl = math.exp(ret_avg_cost[0])
+ if i % 100 == 0:
+ print "step:%d ppl:%.3f" % (i, avg_ppl)
+
+ t1 = time.time()
+ total_time += t1 - t0
+ print "epoch:%d num_steps:%d time_cost(s):%f" % (epoch_idx, i,
+ total_time / epoch_idx)
+
+ save_dir = "%s/epoch_%d" % (model_dir, epoch_idx)
+ feed_var_names = ["src_wordseq", "dst_wordseq"]
+ fetch_vars = [avg_cost]
+ fluid.io.save_inference_model(save_dir, feed_var_names, fetch_vars, exe)
+ print("model saved in %s" % save_dir)
+
+ print("finish training")
+
+
+def train_net():
+ """ do training """
+ batch_size = 20
+ vocab, train_reader, test_reader = utils.prepare_data(
+ batch_size=batch_size, buffer_size=1000, word_freq_threshold=0)
+ train(
+ train_reader=train_reader,
+ vocab=vocab,
+ network=network,
+ hid_size=200,
+ base_lr=1.0,
+ batch_size=batch_size,
+ pass_num=12,
+ use_cuda=True,
+ parallel=False,
+ model_dir="model",
+ init_low_bound=-0.1,
+ init_high_bound=0.1)
+
+
+if __name__ == "__main__":
+ train_net()
diff --git a/fluid/language_model/utils.py b/fluid/language_model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5909046176586556a2aedba5dd5d12810b3ea8d
--- /dev/null
+++ b/fluid/language_model/utils.py
@@ -0,0 +1,40 @@
+import sys
+import time
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.v2 as paddle
+
+
+def to_lodtensor(data, place):
+ """ convert to LODtensor """
+ seq_lens = [len(seq) for seq in data]
+ cur_len = 0
+ lod = [cur_len]
+ for l in seq_lens:
+ cur_len += l
+ lod.append(cur_len)
+ flattened_data = np.concatenate(data, axis=0).astype("int64")
+ flattened_data = flattened_data.reshape([len(flattened_data), 1])
+ res = fluid.LoDTensor()
+ res.set(flattened_data, place)
+ res.set_lod([lod])
+ return res
+
+
+def prepare_data(batch_size, buffer_size=1000, word_freq_threshold=0):
+ """ prepare the English Pann Treebank (PTB) data """
+ vocab = paddle.dataset.imikolov.build_dict(word_freq_threshold)
+ train_reader = paddle.batch(
+ paddle.reader.shuffle(
+ paddle.dataset.imikolov.train(
+ vocab,
+ buffer_size,
+ data_type=paddle.dataset.imikolov.DataType.SEQ),
+ buf_size=buffer_size),
+ batch_size)
+ test_reader = paddle.batch(
+ paddle.dataset.imikolov.test(
+ vocab, buffer_size, data_type=paddle.dataset.imikolov.DataType.SEQ),
+ batch_size)
+ return vocab, train_reader, test_reader
diff --git a/fluid/object_detection/.gitignore b/fluid/object_detection/.gitignore
index 3321aa105e8c63b5ba915782fd69bc90debbf56c..b68dc43d08fbc2415a7c099112350ca940d6519c 100644
--- a/fluid/object_detection/.gitignore
+++ b/fluid/object_detection/.gitignore
@@ -6,3 +6,4 @@ pretrained/ssd_mobilenet_v1_coco
pretrained/mobilenet_v1_imagenet.tar.gz
pretrained/mobilenet_v1_imagenet
log*
+*.log
diff --git a/fluid/object_detection/eval.py b/fluid/object_detection/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..84fbcf82c9862ed0c9e74a4ba5bacd50372ad7ab
--- /dev/null
+++ b/fluid/object_detection/eval.py
@@ -0,0 +1,106 @@
+import os
+import time
+import numpy as np
+import argparse
+import functools
+
+import paddle
+import paddle.fluid as fluid
+import reader
+from mobilenet_ssd import mobile_net
+from utility import add_arguments, print_arguments
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('dataset', str, 'pascalvoc', "coco or pascalvoc.")
+add_arg('batch_size', int, 32, "Minibatch size.")
+add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
+add_arg('data_dir', str, '', "The data root path.")
+add_arg('test_list', str, '', "The testing data lists.")
+add_arg('label_file',       str,   '',        "The label file, which saves the real class names and is only used for Pascal VOC.")
+add_arg('model_dir', str, '', "The model path.")
+add_arg('ap_version', str, '11point', "11point or integral")
+add_arg('resize_h', int, 300, "The resized image height.")
+add_arg('resize_w', int, 300, "The resized image width.")
+add_arg('mean_value_B', float, 127.5, "mean value for B channel which will be subtracted") #123.68
+add_arg('mean_value_G', float, 127.5, "mean value for G channel which will be subtracted") #116.78
+add_arg('mean_value_R', float, 127.5, "mean value for R channel which will be subtracted") #103.94
+# yapf: enable
+
+
+def eval(args, data_args, test_list, batch_size, model_dir=None):
+ image_shape = [3, data_args.resize_h, data_args.resize_w]
+ if data_args.dataset == 'coco':
+ num_classes = 81
+ elif data_args.dataset == 'pascalvoc':
+ num_classes = 21
+
+ image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
+ gt_box = fluid.layers.data(
+ name='gt_box', shape=[4], dtype='float32', lod_level=1)
+ gt_label = fluid.layers.data(
+ name='gt_label', shape=[1], dtype='int32', lod_level=1)
+ difficult = fluid.layers.data(
+ name='gt_difficult', shape=[1], dtype='int32', lod_level=1)
+
+ locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
+ nmsed_out = fluid.layers.detection_output(
+ locs, confs, box, box_var, nms_threshold=0.45)
+ loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box, box_var)
+ loss = fluid.layers.reduce_sum(loss)
+
+ test_program = fluid.default_main_program().clone(for_test=True)
+ with fluid.program_guard(test_program):
+ map_eval = fluid.evaluator.DetectionMAP(
+ nmsed_out,
+ gt_label,
+ gt_box,
+ difficult,
+ num_classes,
+ overlap_threshold=0.5,
+ evaluate_difficult=False,
+ ap_version=args.ap_version)
+
+ place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+ exe = fluid.Executor(place)
+
+ if model_dir:
+
+ def if_exist(var):
+ return os.path.exists(os.path.join(model_dir, var.name))
+
+ fluid.io.load_vars(exe, model_dir, predicate=if_exist)
+
+ test_reader = paddle.batch(
+ reader.test(data_args, test_list), batch_size=batch_size)
+ feeder = fluid.DataFeeder(
+ place=place, feed_list=[image, gt_box, gt_label, difficult])
+
+ _, accum_map = map_eval.get_map_var()
+ map_eval.reset(exe)
+ for idx, data in enumerate(test_reader()):
+ test_map = exe.run(test_program,
+ feed=feeder.feed(data),
+ fetch_list=[accum_map])
+ if idx % 50 == 0:
+ print("Batch {0}, map {1}".format(idx, test_map[0]))
+ print("Test model {0}, map {1}".format(model_dir, test_map[0]))
+
+
+if __name__ == '__main__':
+ args = parser.parse_args()
+ print_arguments(args)
+ data_args = reader.Settings(
+ dataset=args.dataset,
+ data_dir=args.data_dir,
+ label_file=args.label_file,
+ resize_h=args.resize_h,
+ resize_w=args.resize_w,
+ mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R])
+ eval(
+ args,
+ test_list=args.test_list,
+ data_args=data_args,
+ batch_size=args.batch_size,
+ model_dir=args.model_dir)
diff --git a/fluid/object_detection/image_util.py b/fluid/object_detection/image_util.py
index b8464cfe8745b33249a8da3427689aec6904cd99..4ce53048b9f8117e937411392531eeb4090fcb67 100644
--- a/fluid/object_detection/image_util.py
+++ b/fluid/object_detection/image_util.py
@@ -216,7 +216,7 @@ def distort_image(img, settings):
def expand_image(img, bbox_labels, img_width, img_height, settings):
prob = random.uniform(0, 1)
if prob < settings._expand_prob:
- if _expand_max_ratio - 1 >= 0.01:
+ if settings._expand_max_ratio - 1 >= 0.01:
expand_ratio = random.uniform(1, settings._expand_max_ratio)
height = int(img_height * expand_ratio)
width = int(img_width * expand_ratio)
diff --git a/fluid/object_detection/reader.py b/fluid/object_detection/reader.py
index 43c54b4c4f0ed84c35ba98f84a76cf390fd47afd..78efcc4a517001023c72c9d82c6972d60e830c6c 100644
--- a/fluid/object_detection/reader.py
+++ b/fluid/object_detection/reader.py
@@ -25,8 +25,16 @@ import copy
class Settings(object):
- def __init__(self, dataset, toy, data_dir, label_file, resize_h, resize_w,
- mean_value, apply_distort, apply_expand):
+ def __init__(self,
+ dataset=None,
+ data_dir=None,
+ label_file=None,
+ resize_h=300,
+ resize_w=300,
+ mean_value=[127.5, 127.5, 127.5],
+ apply_distort=True,
+ apply_expand=True,
+ toy=0):
self._dataset = dataset
self._toy = toy
self._data_dir = data_dir
@@ -94,169 +102,168 @@ class Settings(object):
return self._img_mean
-def _reader_creator(settings, file_list, mode, shuffle):
+def preprocess(img, bbox_labels, mode, settings):
+ img_width, img_height = img.size
+ sampled_labels = bbox_labels
+ if mode == 'train':
+ if settings._apply_distort:
+ img = image_util.distort_image(img, settings)
+ if settings._apply_expand:
+ img, bbox_labels, img_width, img_height = image_util.expand_image(
+ img, bbox_labels, img_width, img_height, settings)
+ # sampling
+ batch_sampler = []
+    # hard-coded samplers
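+    # the positional args are assumed to follow the common SSD batch-sampler
+    # convention: (max_sample, max_trial, min_scale, max_scale,
+    # min_aspect_ratio, max_aspect_ratio, min_jaccard_overlap, max_jaccard_overlap)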
+ batch_sampler.append(
+ image_util.sampler(1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0))
+ batch_sampler.append(
+ image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0))
+ batch_sampler.append(
+ image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0))
+ batch_sampler.append(
+ image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0))
+ batch_sampler.append(
+ image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0))
+ batch_sampler.append(
+ image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0))
+ batch_sampler.append(
+ image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0))
+ sampled_bbox = image_util.generate_batch_samples(batch_sampler,
+ bbox_labels)
+
+ img = np.array(img)
+ if len(sampled_bbox) > 0:
+ idx = int(random.uniform(0, len(sampled_bbox)))
+ img, sampled_labels = image_util.crop_image(
+ img, bbox_labels, sampled_bbox[idx], img_width, img_height)
+
+ img = Image.fromarray(img)
+ img = img.resize((settings.resize_w, settings.resize_h), Image.ANTIALIAS)
+ img = np.array(img)
+
+ if mode == 'train':
+ mirror = int(random.uniform(0, 2))
+ if mirror == 1:
+ img = img[:, ::-1, :]
+ for i in xrange(len(sampled_labels)):
+ tmp = sampled_labels[i][1]
+ sampled_labels[i][1] = 1 - sampled_labels[i][3]
+ sampled_labels[i][3] = 1 - tmp
+ # HWC to CHW
+ if len(img.shape) == 3:
+ img = np.swapaxes(img, 1, 2)
+ img = np.swapaxes(img, 1, 0)
+    # RGB to BGR
+ img = img[[2, 1, 0], :, :]
+ img = img.astype('float32')
+ img -= settings.img_mean
+ img = img * 0.007843
+ return img, sampled_labels
+
+
+def coco(settings, file_list, mode, shuffle):
+ # cocoapi
+ from pycocotools.coco import COCO
+ from pycocotools.cocoeval import COCOeval
+
+ coco = COCO(file_list)
+ image_ids = coco.getImgIds()
+ images = coco.loadImgs(image_ids)
+ category_ids = coco.getCatIds()
+ category_names = [item['name'] for item in coco.loadCats(category_ids)]
+
+ if not settings.toy == 0:
+ images = images[:settings.toy] if len(images) > settings.toy else images
+ print("{} on {} with {} images".format(mode, settings.dataset, len(images)))
+
def reader():
- if settings.dataset == 'coco':
- # cocoapi
- from pycocotools.coco import COCO
- from pycocotools.cocoeval import COCOeval
-
- coco = COCO(file_list)
- image_ids = coco.getImgIds()
- images = coco.loadImgs(image_ids)
- category_ids = coco.getCatIds()
- category_names = [
- item['name'] for item in coco.loadCats(category_ids)
- ]
- elif settings.dataset == 'pascalvoc':
- flist = open(file_list)
- images = [line.strip() for line in flist]
-
- if not settings.toy == 0:
- images = images[:settings.toy] if len(
- images) > settings.toy else images
- print("{} on {} with {} images".format(mode, settings.dataset,
- len(images)))
-
- if shuffle:
+ if mode == 'train' and shuffle:
random.shuffle(images)
-
for image in images:
- if settings.dataset == 'coco':
- image_name = image['file_name']
- image_path = os.path.join(settings.data_dir, image_name)
- elif settings.dataset == 'pascalvoc':
- if mode == 'train' or mode == 'test':
- image_path, label_path = image.split()
- image_path = os.path.join(settings.data_dir, image_path)
- label_path = os.path.join(settings.data_dir, label_path)
- elif mode == 'infer':
- image_path = os.path.join(settings.data_dir, image)
-
- img = Image.open(image_path)
- if img.mode == 'L':
- img = img.convert('RGB')
- img_width, img_height = img.size
-
- if mode == 'train' or mode == 'test':
- if settings.dataset == 'coco':
- # layout: category_id | xmin | ymin | xmax | ymax | iscrowd | origin_coco_bbox | segmentation | area | image_id | annotation_id
- bbox_labels = []
- annIds = coco.getAnnIds(imgIds=image['id'])
- anns = coco.loadAnns(annIds)
- for ann in anns:
- bbox_sample = []
- # start from 1, leave 0 to background
- bbox_sample.append(
- float(category_ids.index(ann['category_id'])) + 1)
- bbox = ann['bbox']
- xmin, ymin, w, h = bbox
- xmax = xmin + w
- ymax = ymin + h
- bbox_sample.append(float(xmin) / img_width)
- bbox_sample.append(float(ymin) / img_height)
- bbox_sample.append(float(xmax) / img_width)
- bbox_sample.append(float(ymax) / img_height)
- bbox_sample.append(float(ann['iscrowd']))
- #bbox_sample.append(ann['bbox'])
- #bbox_sample.append(ann['segmentation'])
- #bbox_sample.append(ann['area'])
- #bbox_sample.append(ann['image_id'])
- #bbox_sample.append(ann['id'])
- bbox_labels.append(bbox_sample)
- elif settings.dataset == 'pascalvoc':
- # layout: label | xmin | ymin | xmax | ymax | difficult
- bbox_labels = []
- root = xml.etree.ElementTree.parse(label_path).getroot()
- for object in root.findall('object'):
- bbox_sample = []
- # start from 1
- bbox_sample.append(
- float(
- settings.label_list.index(
- object.find('name').text)))
- bbox = object.find('bndbox')
- difficult = float(object.find('difficult').text)
- bbox_sample.append(
- float(bbox.find('xmin').text) / img_width)
- bbox_sample.append(
- float(bbox.find('ymin').text) / img_height)
- bbox_sample.append(
- float(bbox.find('xmax').text) / img_width)
- bbox_sample.append(
- float(bbox.find('ymax').text) / img_height)
- bbox_sample.append(difficult)
- bbox_labels.append(bbox_sample)
-
- sample_labels = bbox_labels
- if mode == 'train':
- if settings._apply_distort:
- img = image_util.distort_image(img, settings)
- if settings._apply_expand:
- img, bbox_labels, img_width, img_height = image_util.expand_image(
- img, bbox_labels, img_width, img_height, settings)
- batch_sampler = []
- # hard-code here
- batch_sampler.append(
- image_util.sampler(1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0))
- batch_sampler.append(
- image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0))
- batch_sampler.append(
- image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0))
- batch_sampler.append(
- image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0))
- batch_sampler.append(
- image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0))
- batch_sampler.append(
- image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0))
- batch_sampler.append(
- image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0))
- """ random crop """
- sampled_bbox = image_util.generate_batch_samples(
- batch_sampler, bbox_labels, img_width, img_height)
-
- img = np.array(img)
- if len(sampled_bbox) > 0:
- idx = int(random.uniform(0, len(sampled_bbox)))
- img, sample_labels = image_util.crop_image(
- img, bbox_labels, sampled_bbox[idx], img_width,
- img_height)
-
- img = Image.fromarray(img)
- img = img.resize((settings.resize_w, settings.resize_h),
- Image.ANTIALIAS)
- img = np.array(img)
-
- if mode == 'train':
- mirror = int(random.uniform(0, 2))
- if mirror == 1:
- img = img[:, ::-1, :]
- for i in xrange(len(sample_labels)):
- tmp = sample_labels[i][1]
- sample_labels[i][1] = 1 - sample_labels[i][3]
- sample_labels[i][3] = 1 - tmp
-
- # HWC to CHW
- if len(img.shape) == 3:
- img = np.swapaxes(img, 1, 2)
- img = np.swapaxes(img, 1, 0)
- # RBG to BGR
- img = img[[2, 1, 0], :, :]
- img = img.astype('float32')
- img -= settings.img_mean
- img = img.flatten()
- img = img * 0.007843
+ image_name = image['file_name']
+ image_path = os.path.join(settings.data_dir, image_name)
+
+ im = Image.open(image_path)
+ if im.mode == 'L':
+ im = im.convert('RGB')
+ im_width, im_height = im.size
+
+ # layout: category_id | xmin | ymin | xmax | ymax | iscrowd |
+ # origin_coco_bbox | segmentation | area | image_id | annotation_id
+ bbox_labels = []
+ annIds = coco.getAnnIds(imgIds=image['id'])
+ anns = coco.loadAnns(annIds)
+ for ann in anns:
+ bbox_sample = []
+ # start from 1, leave 0 to background
+ bbox_sample.append(
+ float(category_ids.index(ann['category_id'])) + 1)
+ bbox = ann['bbox']
+ xmin, ymin, w, h = bbox
+ xmax = xmin + w
+ ymax = ymin + h
+ bbox_sample.append(float(xmin) / im_width)
+ bbox_sample.append(float(ymin) / im_height)
+ bbox_sample.append(float(xmax) / im_width)
+ bbox_sample.append(float(ymax) / im_height)
+ bbox_sample.append(float(ann['iscrowd']))
+ bbox_labels.append(bbox_sample)
+ im, sample_labels = preprocess(im, bbox_labels, mode, settings)
+ sample_labels = np.array(sample_labels)
+ if len(sample_labels) == 0: continue
+ im = im.astype('float32')
+ boxes = sample_labels[:, 1:5]
+ lbls = sample_labels[:, 0].astype('int32')
+ difficults = sample_labels[:, -1].astype('int32')
+ yield im, boxes, lbls, difficults
+
+ return reader
+
+
+def pascalvoc(settings, file_list, mode, shuffle):
+ flist = open(file_list)
+ images = [line.strip() for line in flist]
+ if not settings.toy == 0:
+ images = images[:settings.toy] if len(images) > settings.toy else images
+ print("{} on {} with {} images".format(mode, settings.dataset, len(images)))
+ def reader():
+ if mode == 'train' and shuffle:
+ random.shuffle(images)
+ for image in images:
+ image_path, label_path = image.split()
+ image_path = os.path.join(settings.data_dir, image_path)
+ label_path = os.path.join(settings.data_dir, label_path)
+
+ im = Image.open(image_path)
+ if im.mode == 'L':
+ im = im.convert('RGB')
+ im_width, im_height = im.size
+
+ # layout: label | xmin | ymin | xmax | ymax | difficult
+ bbox_labels = []
+ root = xml.etree.ElementTree.parse(label_path).getroot()
+ for object in root.findall('object'):
+ bbox_sample = []
+ # start from 1
+ bbox_sample.append(
+ float(settings.label_list.index(object.find('name').text)))
+ bbox = object.find('bndbox')
+ difficult = float(object.find('difficult').text)
+ bbox_sample.append(float(bbox.find('xmin').text) / im_width)
+ bbox_sample.append(float(bbox.find('ymin').text) / im_height)
+ bbox_sample.append(float(bbox.find('xmax').text) / im_width)
+ bbox_sample.append(float(bbox.find('ymax').text) / im_height)
+ bbox_sample.append(difficult)
+ bbox_labels.append(bbox_sample)
+ im, sample_labels = preprocess(im, bbox_labels, mode, settings)
sample_labels = np.array(sample_labels)
- if mode == 'train' or mode == 'test':
- if mode == 'train' and len(sample_labels) == 0: continue
- if mode == 'test' and len(sample_labels) == 0: continue
- yield img.astype(
- 'float32'
- ), sample_labels[:, 1:5], sample_labels[:, 0].astype(
- 'int32'), sample_labels[:, -1].astype('int32')
- elif mode == 'infer':
- yield img.astype('float32')
+ if len(sample_labels) == 0: continue
+ im = im.astype('float32')
+ boxes = sample_labels[:, 1:5]
+ lbls = sample_labels[:, 0].astype('int32')
+ difficults = sample_labels[:, -1].astype('int32')
+ yield im, boxes, lbls, difficults
return reader
@@ -301,9 +308,9 @@ def train(settings, file_list, shuffle=True):
elif '2017' in file_list:
sub_dir = "train2017"
train_settings.data_dir = os.path.join(settings.data_dir, sub_dir)
- return _reader_creator(train_settings, file_list, 'train', shuffle)
- elif settings.dataset == 'pascalvoc':
- return _reader_creator(settings, file_list, 'train', shuffle)
+ return coco(train_settings, file_list, 'train', shuffle)
+ else:
+ return pascalvoc(settings, file_list, 'train', shuffle)
def test(settings, file_list):
@@ -315,10 +322,29 @@ def test(settings, file_list):
elif '2017' in file_list:
sub_dir = "val2017"
test_settings.data_dir = os.path.join(settings.data_dir, sub_dir)
- return _reader_creator(test_settings, file_list, 'test', False)
- elif settings.dataset == 'pascalvoc':
- return _reader_creator(settings, file_list, 'test', False)
+ return coco(test_settings, file_list, 'test', False)
+ else:
+ return pascalvoc(settings, file_list, 'test', False)
-def infer(settings, file_list):
- return _reader_creator(settings, file_list, 'infer', False)
+def infer(settings, image_path):
+ def reader():
+ im = Image.open(image_path)
+ if im.mode == 'L':
+ im = im.convert('RGB')
+ im_width, im_height = im.size
+        img = im.resize((settings.resize_w, settings.resize_h),
+ Image.ANTIALIAS)
+ img = np.array(img)
+ # HWC to CHW
+ if len(img.shape) == 3:
+ img = np.swapaxes(img, 1, 2)
+ img = np.swapaxes(img, 1, 0)
+        # RGB to BGR
+ img = img[[2, 1, 0], :, :]
+ img = img.astype('float32')
+ img -= settings.img_mean
+ img = img * 0.007843
+ yield img
+
+ return reader
diff --git a/fluid/object_detection/train.py b/fluid/object_detection/train.py
index 0f2856ca14cd155f600f4cf23a3403262b5bb110..71fa61322d5e58e6726796463b559ccc1e584d7a 100644
--- a/fluid/object_detection/train.py
+++ b/fluid/object_detection/train.py
@@ -1,36 +1,38 @@
-import paddle
-import paddle.fluid as fluid
-import reader
-import load_model as load_model
-from mobilenet_ssd import mobile_net
-from utility import add_arguments, print_arguments
import os
import time
import numpy as np
import argparse
import functools
+import shutil
+
+import paddle
+import paddle.fluid as fluid
+import reader
+from mobilenet_ssd import mobile_net
+from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('batch_size', int, 32, "Minibatch size.")
-add_arg('num_passes', int, 25, "Epoch number.")
+add_arg('num_passes', int, 120, "Epoch number.")
add_arg('parallel', bool, True, "Whether use parallel training.")
-add_arg('use_gpu', bool, True, "Whether use GPU.")
-add_arg('use_nccl', bool, False, "Whether use NCCL.")
+add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
+add_arg('use_nccl', bool, False, "Whether to use NCCL or not.")
add_arg('dataset', str, 'pascalvoc', "coco or pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
-add_arg('apply_distort', bool, True, "Whether apply distort")
-add_arg('apply_expand', bool, False, "Whether appley expand")
-add_arg('resize_h', int, 300, "resize image size")
-add_arg('resize_w', int, 300, "resize image size")
-add_arg('mean_value_B', float, 127.5, "mean value which will be subtracted") #123.68
-add_arg('mean_value_G', float, 127.5, "mean value which will be subtracted") #116.78
-add_arg('mean_value_R', float, 127.5, "mean value which will be subtracted") #103.94
+add_arg('apply_distort',    bool,  True,      "Whether to apply distortion.")
+add_arg('apply_expand',     bool,  True,      "Whether to apply expansion.")
+add_arg('ap_version', str, '11point', "11point or integral")
+add_arg('resize_h', int, 300, "The resized image height.")
+add_arg('resize_w', int, 300, "The resized image width.")
+add_arg('mean_value_B', float, 127.5, "mean value for B channel which will be subtracted") #123.68
+add_arg('mean_value_G', float, 127.5, "mean value for G channel which will be subtracted") #116.78
+add_arg('mean_value_R', float, 127.5, "mean value for R channel which will be subtracted") #103.94
add_arg('is_toy', int, 0, "Toy for quick debug, 0 means using all data, while n means using only n sample")
-# yapf: disable
+# yapf: enable
def parallel_do(args,
@@ -94,7 +96,7 @@ def parallel_do(args,
num_classes,
overlap_threshold=0.5,
evaluate_difficult=False,
- ap_version='integral')
+ ap_version=args.ap_version)
if data_args.dataset == 'coco':
# learning rate decay in 12, 19 pass, respectively
@@ -116,8 +118,10 @@ def parallel_do(args,
exe.run(fluid.default_startup_program())
if pretrained_model:
+
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
+
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
train_reader = paddle.batch(
@@ -131,7 +135,7 @@ def parallel_do(args,
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
test_map = None
- for _, data in enumerate(test_reader()):
+ for data in test_reader():
test_map = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
@@ -174,6 +178,9 @@ def parallel_exe(args,
elif data_args.dataset == 'pascalvoc':
num_classes = 21
+ devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
+ devices_num = len(devices.split(","))
+
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
gt_box = fluid.layers.data(
name='gt_box', shape=[4], dtype='float32', lod_level=1)
@@ -185,8 +192,7 @@ def parallel_exe(args,
locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
- loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box,
- box_var)
+ loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box, box_var)
loss = fluid.layers.reduce_sum(loss)
test_program = fluid.default_main_program().clone(for_test=True)
@@ -199,17 +205,23 @@ def parallel_exe(args,
num_classes,
overlap_threshold=0.5,
evaluate_difficult=False,
- ap_version='integral')
+ ap_version=args.ap_version)
if data_args.dataset == 'coco':
# learning rate decay in 12, 19 pass, respectively
if '2014' in train_file_list:
- boundaries = [82783 / batch_size * 12, 82783 / batch_size * 19]
+ epocs = 82783 / batch_size
+ boundaries = [epocs * 12, epocs * 19]
elif '2017' in train_file_list:
- boundaries = [118287 / batch_size * 12, 118287 / batch_size * 19]
+ epocs = 118287 / batch_size
+            boundaries = [epocs * 12, epocs * 19]
elif data_args.dataset == 'pascalvoc':
- boundaries = [40000, 60000]
- values = [learning_rate, learning_rate * 0.5, learning_rate * 0.25]
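+        # assumes ~19200 images in the Pascal VOC training list, so the
+        # learning rate decays at roughly epochs 40/60/80/100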
+ epocs = 19200 / batch_size
+ boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
+ values = [
+ learning_rate, learning_rate * 0.5, learning_rate * 0.25,
+ learning_rate * 0.1, learning_rate * 0.01
+ ]
optimizer = fluid.optimizer.RMSProp(
learning_rate=fluid.layers.piecewise_decay(boundaries, values),
regularization=fluid.regularizer.L2Decay(0.00005), )
@@ -221,12 +233,15 @@ def parallel_exe(args,
exe.run(fluid.default_startup_program())
if pretrained_model:
+
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
+
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
- train_exe = fluid.ParallelExecutor(use_cuda=args.use_gpu,
- loss_name=loss.name)
+ if args.parallel:
+ train_exe = fluid.ParallelExecutor(
+ use_cuda=args.use_gpu, loss_name=loss.name)
train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size)
@@ -235,36 +250,53 @@ def parallel_exe(args,
feeder = fluid.DataFeeder(
place=place, feed_list=[image, gt_box, gt_label, difficult])
- def test(pass_id):
+ def save_model(postfix):
+ model_path = os.path.join(model_save_dir, postfix)
+ if os.path.isdir(model_path):
+ shutil.rmtree(model_path)
+ print 'save models to %s' % (model_path)
+ fluid.io.save_persistables(exe, model_path)
+
+ best_map = 0.
+
+ def test(pass_id, best_map):
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
test_map = None
- for _, data in enumerate(test_reader()):
+ for data in test_reader():
test_map = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
+ if test_map[0] > best_map:
+ best_map = test_map[0]
+ save_model('best_model')
print("Test {0}, map {1}".format(pass_id, test_map[0]))
for pass_id in range(num_passes):
start_time = time.time()
prev_start_time = start_time
end_time = 0
- test(pass_id)
for batch_id, data in enumerate(train_reader()):
prev_start_time = start_time
start_time = time.time()
- loss_v, = train_exe.run(fetch_list=[loss.name],
- feed_dict=feeder.feed(data))
+ if len(data) < devices_num: continue
+ if args.parallel:
+ loss_v, = train_exe.run(fetch_list=[loss.name],
+ feed_dict=feeder.feed(data))
+ else:
+ loss_v, = exe.run(fluid.default_main_program(),
+ feed=feeder.feed(data),
+ fetch_list=[loss])
end_time = time.time()
loss_v = np.mean(np.array(loss_v))
if batch_id % 20 == 0:
print("Pass {0}, batch {1}, loss {2}, time {3}".format(
pass_id, batch_id, loss_v, start_time - prev_start_time))
-
+        best_map = test(pass_id, best_map)
if pass_id % 10 == 0 or pass_id == num_passes - 1:
- model_path = os.path.join(model_save_dir, str(pass_id))
- print 'save models to %s' % (model_path)
- fluid.io.save_persistables(exe, model_path)
+ save_model(str(pass_id))
+ print("Best test map {0}".format(best_map))
+
if __name__ == '__main__':
args = parser.parse_args()
@@ -283,22 +315,23 @@ if __name__ == '__main__':
data_args = reader.Settings(
dataset=args.dataset,
- toy=args.is_toy,
data_dir=data_dir,
label_file=label_file,
apply_distort=args.apply_distort,
apply_expand=args.apply_expand,
resize_h=args.resize_h,
resize_w=args.resize_w,
- mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R])
+ mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
+ toy=args.is_toy)
#method = parallel_do
method = parallel_exe
- method(args,
- train_file_list=train_file_list,
- val_file_list=val_file_list,
- data_args=data_args,
- learning_rate=args.learning_rate,
- batch_size=args.batch_size,
- num_passes=args.num_passes,
- model_save_dir=model_save_dir,
- pretrained_model=args.pretrained_model)
+ method(
+ args,
+ train_file_list=train_file_list,
+ val_file_list=val_file_list,
+ data_args=data_args,
+ learning_rate=args.learning_rate,
+ batch_size=args.batch_size,
+ num_passes=args.num_passes,
+ model_save_dir=model_save_dir,
+ pretrained_model=args.pretrained_model)
diff --git a/fluid/text_classification/README.md b/fluid/text_classification/README.md
index 500ee6ae6db28e9d844d206a1cc894c36f1db09f..43c15934fa62af3db2261be37803ce21ba6bf946 100644
--- a/fluid/text_classification/README.md
+++ b/fluid/text_classification/README.md
@@ -1,16 +1,112 @@
-The minimum PaddlePaddle version needed for the code sample in this directory is the lastest develop branch. If you are on a version of PaddlePaddle earlier than this, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html).
+# Text Classification
----
+The following is a brief overview of this example's directory structure:
-# Text Classification
-
-## Data Preparation
-```
-wget http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz
-tar zxf aclImdb_v1.tar.gz
+```text
+.
+├── nets.py            # network definitions
+├── README.md          # documentation
+├── train.py           # training script
+├── infer.py           # inference script
+└── utils.py           # common utility functions
```
-## Training
+
+## Introduction and Models
+
+The PaddlePaddle v2 [text classification](https://github.com/PaddlePaddle/models/blob/develop/text/README.md) example introduces the text classification task in some detail, so it is not repeated here.
+For the models, we adopt four common text classification networks: bow, cnn, lstm and gru.
+
+## Training
+
+1. Run the command `python train.py bow` to start training:
+   ```bash
+   python train.py bow # "bow" selects the network; it can be replaced with cnn, lstm or gru
+ ```
+
+2. (Optional) To customize the network, add your own model in [nets.py](./nets.py) and set the corresponding parameters of [train.py](./train.py); a usage sketch follows the signature below.
+ ```python
+   def train(train_reader,   # training data reader
+             word_dict,      # word dictionary
+             network,        # network configuration
+             use_cuda,       # whether to use GPU
+             parallel,       # whether to train in parallel
+             save_dirname,   # directory for saving the model
+             lr=0.2,         # learning rate
+             batch_size=128, # number of samples per batch
+             pass_num=30):   # number of training passes
+ ```
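+
+   As a minimal usage sketch under these defaults (`train_reader` and `word_dict` come from `utils.prepare_data`, as in this example's `train_net`; the save directory name is illustrative):
+   ```python
+   from nets import lstm_net
+   train(train_reader, word_dict, lstm_net, use_cuda=True,
+         parallel=False, save_dirname="lstm_model")
+   ```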
+
+## Sample Training Results
+```text
+ pass_id: 0, avg_acc: 0.848040, avg_cost: 0.354073
+ pass_id: 1, avg_acc: 0.914200, avg_cost: 0.217945
+ pass_id: 2, avg_acc: 0.929800, avg_cost: 0.184302
+ pass_id: 3, avg_acc: 0.938680, avg_cost: 0.164240
+ pass_id: 4, avg_acc: 0.945120, avg_cost: 0.149150
+ pass_id: 5, avg_acc: 0.951280, avg_cost: 0.137117
+ pass_id: 6, avg_acc: 0.955360, avg_cost: 0.126434
+ pass_id: 7, avg_acc: 0.961400, avg_cost: 0.117405
+ pass_id: 8, avg_acc: 0.963560, avg_cost: 0.110070
+ pass_id: 9, avg_acc: 0.965840, avg_cost: 0.103273
+ pass_id: 10, avg_acc: 0.969800, avg_cost: 0.096314
+ pass_id: 11, avg_acc: 0.971720, avg_cost: 0.090206
+ pass_id: 12, avg_acc: 0.974800, avg_cost: 0.084970
+ pass_id: 13, avg_acc: 0.977400, avg_cost: 0.078981
+ pass_id: 14, avg_acc: 0.980000, avg_cost: 0.073685
+ pass_id: 15, avg_acc: 0.981080, avg_cost: 0.069898
+ pass_id: 16, avg_acc: 0.982080, avg_cost: 0.064923
+ pass_id: 17, avg_acc: 0.984680, avg_cost: 0.060861
+ pass_id: 18, avg_acc: 0.985840, avg_cost: 0.057095
+ pass_id: 19, avg_acc: 0.988080, avg_cost: 0.052424
+ pass_id: 20, avg_acc: 0.989160, avg_cost: 0.049059
+ pass_id: 21, avg_acc: 0.990120, avg_cost: 0.045882
+ pass_id: 22, avg_acc: 0.992080, avg_cost: 0.042140
+ pass_id: 23, avg_acc: 0.992280, avg_cost: 0.039722
+ pass_id: 24, avg_acc: 0.992840, avg_cost: 0.036607
+ pass_id: 25, avg_acc: 0.994440, avg_cost: 0.034040
+ pass_id: 26, avg_acc: 0.995000, avg_cost: 0.031501
+ pass_id: 27, avg_acc: 0.995440, avg_cost: 0.028988
+ pass_id: 28, avg_acc: 0.996240, avg_cost: 0.026639
+ pass_id: 29, avg_acc: 0.996960, avg_cost: 0.024186
```
-python train.py --dict_path 'aclImdb/imdb.vocab'
+
+## Inference
+1. Run the command `python infer.py bow_model` to start inference:
+   ```bash
+   python infer.py bow_model # bow_model specifies the directory of the model to load
+   ```
+
+## Sample Inference Results
+```text
+ model_path: bow_model/epoch0, avg_acc: 0.882800
+ model_path: bow_model/epoch1, avg_acc: 0.882360
+ model_path: bow_model/epoch2, avg_acc: 0.881400
+ model_path: bow_model/epoch3, avg_acc: 0.877800
+ model_path: bow_model/epoch4, avg_acc: 0.872920
+ model_path: bow_model/epoch5, avg_acc: 0.872640
+ model_path: bow_model/epoch6, avg_acc: 0.869960
+ model_path: bow_model/epoch7, avg_acc: 0.865160
+ model_path: bow_model/epoch8, avg_acc: 0.863680
+ model_path: bow_model/epoch9, avg_acc: 0.861200
+ model_path: bow_model/epoch10, avg_acc: 0.853520
+ model_path: bow_model/epoch11, avg_acc: 0.850400
+ model_path: bow_model/epoch12, avg_acc: 0.855960
+ model_path: bow_model/epoch13, avg_acc: 0.853480
+ model_path: bow_model/epoch14, avg_acc: 0.855960
+ model_path: bow_model/epoch15, avg_acc: 0.854120
+ model_path: bow_model/epoch16, avg_acc: 0.854160
+ model_path: bow_model/epoch17, avg_acc: 0.852240
+ model_path: bow_model/epoch18, avg_acc: 0.852320
+ model_path: bow_model/epoch19, avg_acc: 0.850280
+ model_path: bow_model/epoch20, avg_acc: 0.849760
+ model_path: bow_model/epoch21, avg_acc: 0.850160
+ model_path: bow_model/epoch22, avg_acc: 0.846800
+ model_path: bow_model/epoch23, avg_acc: 0.845440
+ model_path: bow_model/epoch24, avg_acc: 0.845640
+ model_path: bow_model/epoch25, avg_acc: 0.846200
+ model_path: bow_model/epoch26, avg_acc: 0.845880
+ model_path: bow_model/epoch27, avg_acc: 0.844880
+ model_path: bow_model/epoch28, avg_acc: 0.844680
+ model_path: bow_model/epoch29, avg_acc: 0.844960
```
+Note: the steady drop in accuracy is caused by overfitting and can be ignored.
diff --git a/fluid/text_classification/config.py b/fluid/text_classification/config.py
deleted file mode 100644
index 2aba3247eb9033d959bbf4a7c3d475d5c8309058..0000000000000000000000000000000000000000
--- a/fluid/text_classification/config.py
+++ /dev/null
@@ -1,16 +0,0 @@
-class TrainConfig(object):
-
- # Whether to use GPU in training or not.
- use_gpu = False
-
- # The training batch size.
- batch_size = 4
-
- # The epoch number.
- num_passes = 30
-
- # The global learning rate.
- learning_rate = 0.01
-
- # Training log will be printed every log_period.
- log_period = 100
diff --git a/fluid/text_classification/infer.py b/fluid/text_classification/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2a0363d786866a92195dba8b490287b3ca9bc9d
--- /dev/null
+++ b/fluid/text_classification/infer.py
@@ -0,0 +1,50 @@
+import sys
+import time
+import unittest
+import contextlib
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.v2 as paddle
+
+import utils
+
+
+def infer(test_reader, use_cuda, model_path=None):
+ """
+ inference function
+ """
+ if model_path is None:
+ print(str(model_path) + " cannot be found")
+ return
+
+ place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+ exe = fluid.Executor(place)
+
+ inference_scope = fluid.core.Scope()
+ with fluid.scope_guard(inference_scope):
+ [inference_program, feed_target_names,
+ fetch_targets] = fluid.io.load_inference_model(model_path, exe)
+
+ total_acc = 0.0
+ total_count = 0
+ for data in test_reader():
+ acc = exe.run(inference_program,
+ feed=utils.data2tensor(data, place),
+ fetch_list=fetch_targets,
+ return_numpy=True)
+ total_acc += acc[0] * len(data)
+ total_count += len(data)
+
+ avg_acc = total_acc / total_count
+ print("model_path: %s, avg_acc: %f" % (model_path, avg_acc))
+
+
+if __name__ == "__main__":
+ word_dict, train_reader, test_reader = utils.prepare_data(
+ "imdb", self_dict=False, batch_size=128, buf_size=50000)
+
+ model_path = sys.argv[1]
+ for i in range(30):
+ epoch_path = model_path + "/" + "epoch" + str(i)
+ infer(test_reader, use_cuda=False, model_path=epoch_path)
diff --git a/fluid/text_classification/nets.py b/fluid/text_classification/nets.py
new file mode 100644
index 0000000000000000000000000000000000000000..a21742d22d0bd1676c8c5874899af746b5225636
--- /dev/null
+++ b/fluid/text_classification/nets.py
@@ -0,0 +1,124 @@
+import sys
+import time
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.v2 as paddle
+
+
+def bow_net(data,
+ label,
+ dict_dim,
+ emb_dim=128,
+ hid_dim=128,
+ hid_dim2=96,
+ class_dim=2):
+ """
+ bow net
+ """
+ emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
+ bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
+ bow_tanh = fluid.layers.tanh(bow)
+ fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
+ fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
+ prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
+ cost = fluid.layers.cross_entropy(input=prediction, label=label)
+ avg_cost = fluid.layers.mean(x=cost)
+ acc = fluid.layers.accuracy(input=prediction, label=label)
+
+ return avg_cost, acc, prediction
+
+
+def cnn_net(data,
+ label,
+ dict_dim,
+ emb_dim=128,
+ hid_dim=128,
+ hid_dim2=96,
+ class_dim=2,
+ win_size=3):
+ """
+ conv net
+ """
+ emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
+
+ conv_3 = fluid.nets.sequence_conv_pool(
+ input=emb,
+ num_filters=hid_dim,
+ filter_size=win_size,
+ act="tanh",
+ pool_type="max")
+
+ fc_1 = fluid.layers.fc(input=[conv_3], size=hid_dim2)
+
+ prediction = fluid.layers.fc(input=[fc_1], size=class_dim, act="softmax")
+ cost = fluid.layers.cross_entropy(input=prediction, label=label)
+ avg_cost = fluid.layers.mean(x=cost)
+ acc = fluid.layers.accuracy(input=prediction, label=label)
+
+ return avg_cost, acc, prediction
+
+
+def lstm_net(data,
+ label,
+ dict_dim,
+ emb_dim=128,
+ hid_dim=128,
+ hid_dim2=96,
+ class_dim=2,
+ emb_lr=30.0):
+ """
+ lstm net
+ """
+ emb = fluid.layers.embedding(
+ input=data,
+ size=[dict_dim, emb_dim],
+ param_attr=fluid.ParamAttr(learning_rate=emb_lr))
+
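+    # dynamic_lstm requires an input width of 4 * the hidden size, hence hid_dim * 4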
+ fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4, act='tanh')
+
+ lstm_h, c = fluid.layers.dynamic_lstm(
+ input=fc0, size=hid_dim * 4, is_reverse=False)
+
+ lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
+ lstm_max_tanh = fluid.layers.tanh(lstm_max)
+
+ fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
+
+ prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
+
+ cost = fluid.layers.cross_entropy(input=prediction, label=label)
+ avg_cost = fluid.layers.mean(x=cost)
+ acc = fluid.layers.accuracy(input=prediction, label=label)
+
+ return avg_cost, acc, prediction
+
+
+def gru_net(data,
+ label,
+ dict_dim,
+ emb_dim=128,
+ hid_dim=128,
+ hid_dim2=96,
+ class_dim=2,
+ emb_lr=400.0):
+ """
+ gru net
+ """
+ emb = fluid.layers.embedding(
+ input=data,
+ size=[dict_dim, emb_dim],
+ param_attr=fluid.ParamAttr(learning_rate=emb_lr))
+
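+    # dynamic_gru requires an input width of 3 * the hidden size, hence hid_dim * 3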
+ fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3)
+ gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
+ gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max')
+ gru_max_tanh = fluid.layers.tanh(gru_max)
+ fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh')
+ prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
+
+ cost = fluid.layers.cross_entropy(input=prediction, label=label)
+ avg_cost = fluid.layers.mean(x=cost)
+ acc = fluid.layers.accuracy(input=prediction, label=label)
+
+ return avg_cost, acc, prediction
diff --git a/fluid/text_classification/train.py b/fluid/text_classification/train.py
index d32e1c4c878f4d6ef554cc27e0fc5ffc99f96a4a..dc164671e785b758365885b98788fae71d5f8a87 100644
--- a/fluid/text_classification/train.py
+++ b/fluid/text_classification/train.py
@@ -1,164 +1,131 @@
-import numpy as np
import sys
-import os
-import argparse
import time
+import unittest
+import contextlib
-import paddle.v2 as paddle
import paddle.fluid as fluid
+import paddle.v2 as paddle
-from config import TrainConfig as conf
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '--dict_path',
- type=str,
- required=True,
- help="Path of the word dictionary.")
- return parser.parse_args()
-
-
-# Define to_lodtensor function to process the sequential data.
-def to_lodtensor(data, place):
- seq_lens = [len(seq) for seq in data]
- cur_len = 0
- lod = [cur_len]
- for l in seq_lens:
- cur_len += l
- lod.append(cur_len)
- flattened_data = np.concatenate(data, axis=0).astype("int64")
- flattened_data = flattened_data.reshape([len(flattened_data), 1])
- res = fluid.LoDTensor()
- res.set(flattened_data, place)
- res.set_lod([lod])
- return res
-
-
-# Load the dictionary.
-def load_vocab(filename):
- vocab = {}
- with open(filename) as f:
- for idx, line in enumerate(f):
- vocab[line.strip()] = idx
- return vocab
-
-
-# Define the convolution model.
-def conv_net(dict_dim,
- window_size=3,
- emb_dim=128,
- num_filters=128,
- fc0_dim=96,
- class_dim=2):
-
+import utils
+from nets import bow_net
+from nets import cnn_net
+from nets import lstm_net
+from nets import gru_net
+
+
+def train(train_reader,
+ word_dict,
+ network,
+ use_cuda,
+ parallel,
+ save_dirname,
+ lr=0.2,
+ batch_size=128,
+ pass_num=30):
+ """
+ train network
+ """
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
- emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
-
- conv_3 = fluid.nets.sequence_conv_pool(
- input=emb,
- num_filters=num_filters,
- filter_size=window_size,
- act="tanh",
- pool_type="max")
-
- fc_0 = fluid.layers.fc(input=[conv_3], size=fc0_dim)
-
- prediction = fluid.layers.fc(input=[fc_0], size=class_dim, act="softmax")
-
- cost = fluid.layers.cross_entropy(input=prediction, label=label)
-
- avg_cost = fluid.layers.mean(x=cost)
-
- return data, label, prediction, avg_cost
-
-
-def main(dict_path):
- word_dict = load_vocab(dict_path)
- word_dict["