提交 7857d7f7 编写于 作者: D dangqingqing

Bug fix for image classification demo and ResNet model zoo when using CPU.

But need to check consistency of CPU and GPU later for conv layer.
ISSUE=4592155

git-svn-id: https://svn.baidu.com/idl/trunk/paddle@1427 1ad973e4-5ce8-4261-8a94-b56d1f490c56
上级 57109591
...@@ -33,6 +33,7 @@ logging.getLogger().setLevel(logging.INFO) ...@@ -33,6 +33,7 @@ logging.getLogger().setLevel(logging.INFO)
class ImageClassifier(): class ImageClassifier():
def __init__(self, train_conf, model_dir=None, def __init__(self, train_conf, model_dir=None,
resize_dim=256, crop_dim=224, resize_dim=256, crop_dim=224,
use_gpu=True,
mean_file=None, mean_file=None,
output_layer=None, output_layer=None,
oversample=False, is_color=True): oversample=False, is_color=True):
...@@ -76,9 +77,9 @@ class ImageClassifier(): ...@@ -76,9 +77,9 @@ class ImageClassifier():
# this three mean value is calculated from ImageNet. # this three mean value is calculated from ImageNet.
self.transformer.set_mean(np.array([103.939,116.779,123.68])) self.transformer.set_mean(np.array([103.939,116.779,123.68]))
conf_args = "is_test=1,use_gpu=1,is_predict=1" conf_args = "is_test=1,use_gpu=%d,is_predict=1" % (int(use_gpu))
conf = parse_config(train_conf, conf_args) conf = parse_config(train_conf, conf_args)
swig_paddle.initPaddle("--use_gpu=1") swig_paddle.initPaddle("--use_gpu=%d" % (int(use_gpu)))
self.network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config) self.network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
assert isinstance(self.network, swig_paddle.GradientMachine) assert isinstance(self.network, swig_paddle.GradientMachine)
self.network.loadParameters(self.model_dir) self.network.loadParameters(self.model_dir)
...@@ -236,6 +237,9 @@ def option_parser(): ...@@ -236,6 +237,9 @@ def option_parser():
parser.add_option("-w", "--model", parser.add_option("-w", "--model",
action="store", dest="model_path", action="store", dest="model_path",
default=None, help="model path") default=None, help="model path")
parser.add_option("-g", "--use_gpu", action="store",
dest="use_gpu", default=True,
help="Whether to use gpu mode.")
parser.add_option("-o", "--output_dir", parser.add_option("-o", "--output_dir",
action="store", dest="output_dir", action="store", dest="output_dir",
default="output", help="output path") default="output", help="output path")
...@@ -259,10 +263,11 @@ def main(): ...@@ -259,10 +263,11 @@ def main():
""" """
options, args = option_parser() options, args = option_parser()
obj = ImageClassifier(options.train_conf, obj = ImageClassifier(options.train_conf,
options.model_path, options.model_path,
mean_file=options.mean, use_gpu=options.use_gpu,
output_layer=options.output_layer, mean_file=options.mean,
oversample=options.multi_crop) output_layer=options.output_layer,
oversample=options.multi_crop)
if options.job_type == "predict": if options.job_type == "predict":
obj.predict(options.data_file) obj.predict(options.data_file)
......
...@@ -14,11 +14,16 @@ ...@@ -14,11 +14,16 @@
# limitations under the License. # limitations under the License.
set -e set -e
#Note if you use CPU mode, you need to set use_gpu=0 in classify.py. like this:
#conf_args = "is_test=0,use_gpu=1,is_predict=1"
#conf = parse_config(train_conf, conf_args)
#swig_paddle.initPaddle("--use_gpu=0")
python classify.py \ python classify.py \
--job=extract \ --job=extract \
--conf=resnet.py \ --conf=resnet.py \
--use_gpu=1 \
--mean=model/mean_meta_224/mean.meta \ --mean=model/mean_meta_224/mean.meta \
--model=model/resnet_50 \ --model=model/resnet_50 \
--data=./example/test.list \ --data=./example/test.list \
--output_layer="res5_3_branch2c_conv,res5_3_branch2c_bn" \ --output_layer="res5_3_branch2c_conv,res5_3_branch2c_bn" \
--output_dir=features --output_dir=features
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
set -e
:' :'
Visual deep residual network Visual deep residual network
...@@ -23,6 +22,8 @@ Usage: ...@@ -23,6 +22,8 @@ Usage:
./net_diagram.sh ./net_diagram.sh
' '
set -e
DIR="$( cd "$(dirname "$0")" ; pwd -P )" DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd $DIR cd $DIR
......
...@@ -19,4 +19,5 @@ python classify.py \ ...@@ -19,4 +19,5 @@ python classify.py \
--conf=resnet.py\ --conf=resnet.py\
--model=model/resnet_50 \ --model=model/resnet_50 \
--multi_crop \ --multi_crop \
--use_gpu=1 \
--data=./example/test.list --data=./example/test.list
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
# limitations under the License. # limitations under the License.
set -e set -e
config=trainer_config.py #Note the default model is pass-00002, you shold make sure the model path
#exists or change the mode path.
model=model_output/pass-00002/ model=model_output/pass-00002/
config=trainer_config.py
label=data/pre-imdb/labels.list label=data/pre-imdb/labels.list
python predict.py \ python predict.py \
-n $config\ -n $config\
......
...@@ -11,7 +11,7 @@ First, download CIFAR-10 dataset. CIFAR-10 dataset can be downloaded from its of ...@@ -11,7 +11,7 @@ First, download CIFAR-10 dataset. CIFAR-10 dataset can be downloaded from its of
<https://www.cs.toronto.edu/~kriz/cifar.html> <https://www.cs.toronto.edu/~kriz/cifar.html>
We have prepared a script to download and process CIFAR-10 dataset. The script will download CIFAR-10 dataset from the official dataset. We have prepared a script to download and process CIFAR-10 dataset. The script will download CIFAR-10 dataset from the official dataset.
It will convert it to jpeg images and organize them into a directory with the required structure for the tutorial. Make sure that you have installed the python dependency (PIL). It will convert it to jpeg images and organize them into a directory with the required structure for the tutorial. Make sure that you have installed the python dependency (PIL). If not, you can install it by `pip install PIL` and if you have installed `pip` package.
```bash ```bash
cd demo/image_classification/data/ cd demo/image_classification/data/
......
...@@ -223,6 +223,7 @@ extract_fea_py.sh: ...@@ -223,6 +223,7 @@ extract_fea_py.sh:
python classify.py \ python classify.py \
--job=extract \ --job=extract \
--conf=resnet.py\ --conf=resnet.py\
--use_gpu=1 \
--mean=model/mean_meta_224/mean.meta \ --mean=model/mean_meta_224/mean.meta \
--model=model/resnet_50 \ --model=model/resnet_50 \
--data=./example/test.list \ --data=./example/test.list \
...@@ -230,12 +231,15 @@ python classify.py \ ...@@ -230,12 +231,15 @@ python classify.py \
--output_dir=features --output_dir=features
``` ```
* --job=extract: specify job mode to extract feature. * \--job=extract: specify job mode to extract feature.
* --conf=resnet.py: network configure. * \--conf=resnet.py: network configure.
* --model=model/resnet_5: model path. * \--use_gpu=1: speficy GPU mode.
* --data=./example/test.list: data list. * \--model=model/resnet_5: model path.
* --output_layer="xxx,xxx": specify layers to extract features. * \--data=./example/test.list: data list.
* --output_dir=features: output diretcoty. * \--output_layer="xxx,xxx": specify layers to extract features.
* \--output_dir=features: output diretcoty.
Note, since the convolution layer in these ResNet models is suitable for the cudnn implementation which only support GPU. It not support CPU mode because of compatibility issue and we will fix later.
If run successfully, you will see features saved in `features/batch_0`, this file is produced with cPickle. You can use `load_feature_py` interface in `load_feature.py` to open the file, and it returns a dictionary as follows: If run successfully, you will see features saved in `features/batch_0`, this file is produced with cPickle. You can use `load_feature_py` interface in `load_feature.py` to open the file, and it returns a dictionary as follows:
...@@ -265,13 +269,15 @@ python classify.py \ ...@@ -265,13 +269,15 @@ python classify.py \
--conf=resnet.py\ --conf=resnet.py\
--multi_crop \ --multi_crop \
--model=model/resnet_50 \ --model=model/resnet_50 \
--use_gpu=1 \
--data=./example/test.list --data=./example/test.list
``` ```
* --job=extract: speficy job mode to predict. * \--job=extract: speficy job mode to predict.
* --conf=resnet.py: network configure. * \--conf=resnet.py: network configure.
* --multi_crop: use 10 crops and average predicting probability. * \--multi_crop: use 10 crops and average predicting probability.
* --model=model/resnet_50: model path. * \--use_gpu=1: speficy GPU mode.
* --data=./example/test.list: data list. * \--model=model/resnet_50: model path.
* \--data=./example/test.list: data list.
If run successfully, you will see following results, where 156 and 285 are labels of the images. If run successfully, you will see following results, where 156 and 285 are labels of the images.
......
...@@ -204,15 +204,15 @@ paddle train --config=$config \ ...@@ -204,15 +204,15 @@ paddle train --config=$config \
2>&1 | tee 'train.log' 2>&1 | tee 'train.log'
``` ```
* --config=$config: set network config. * \--config=$config: set network config.
* --save\_dir=$output: set output path to save models. * \--save\_dir=$output: set output path to save models.
* --job=train: set job mode to train. * \--job=train: set job mode to train.
* --use\_gpu=false: use CPU to train, set true, if you install GPU version of PaddlePaddle and want to use GPU to train. * \--use\_gpu=false: use CPU to train, set true, if you install GPU version of PaddlePaddle and want to use GPU to train.
* --trainer\_count=4: set thread number (or GPU count). * \--trainer\_count=4: set thread number (or GPU count).
* --num\_passes=15: set pass number, one pass in PaddlePaddle means training all samples in dataset one time. * \--num\_passes=15: set pass number, one pass in PaddlePaddle means training all samples in dataset one time.
* --log\_period=20: print log every 20 batches. * \--log\_period=20: print log every 20 batches.
* --show\_parameter\_stats\_period=100: show parameter statistic every 100 batches. * \--show\_parameter\_stats\_period=100: show parameter statistic every 100 batches.
* --test\_all_data\_in\_one\_period=1: test all data every testing. * \--test\_all_data\_in\_one\_period=1: test all data every testing.
If the run succeeds, the output log is saved in path of `demo/sentiment/train.log` and model is saved in path of `demo/sentiment/model_output/`. The output log is explained as follows. If the run succeeds, the output log is saved in path of `demo/sentiment/train.log` and model is saved in path of `demo/sentiment/model_output/`. The output log is explained as follows.
...@@ -286,8 +286,10 @@ cd demo/sentiment ...@@ -286,8 +286,10 @@ cd demo/sentiment
predict.sh: predict.sh:
``` ```
config=trainer_config.py #Note the default model is pass-00002, you shold make sure the model path
#exists or change the mode path.
model=model_output/pass-00002/ model=model_output/pass-00002/
config=trainer_config.py
label=data/pre-imdb/labels.list label=data/pre-imdb/labels.list
python predict.py \ python predict.py \
-n $config\ -n $config\
...@@ -304,6 +306,9 @@ python predict.py \ ...@@ -304,6 +306,9 @@ python predict.py \
* -d data/pre-imdb/dict.txt: set dictionary. * -d data/pre-imdb/dict.txt: set dictionary.
* -i data/aclImdb/test/pos/10014_7.txt: set one example file to predict. * -i data/aclImdb/test/pos/10014_7.txt: set one example file to predict.
Note you should make sure the default model path `model_output/pass-00002`
exists or change the model path.
Predicting result of this example: Predicting result of this example:
``` ```
......
...@@ -1157,7 +1157,8 @@ void CpuMatrix::copyFrom(const Matrix& src) { ...@@ -1157,7 +1157,8 @@ void CpuMatrix::copyFrom(const Matrix& src) {
CHECK(elementCnt_ == src.getElementCnt()); CHECK(elementCnt_ == src.getElementCnt());
hl_memcpy_device2host(data_, const_cast<real*>(src.getData()), hl_memcpy_device2host(data_, const_cast<real*>(src.getData()),
sizeof(real) * elementCnt_); sizeof(real) * elementCnt_);
} else if (typeid(src) == typeid(CpuMatrix)) { } else if (typeid(src) == typeid(CpuMatrix) ||
typeid(src) == typeid(SharedCpuMatrix)) {
CHECK(src.isContiguous()); CHECK(src.isContiguous());
CHECK(elementCnt_ == src.getElementCnt()); CHECK(elementCnt_ == src.getElementCnt());
memcpy(data_, src.getData(), sizeof(real) * elementCnt_); memcpy(data_, src.getData(), sizeof(real) * elementCnt_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册