提交 65e1a0a8 编写于 作者: W wuzewu

add demo script

上级 de93df86
#!/bin/bash
set -o nounset
set -o errexit
script_path=$(cd `dirname $0`; pwd)
cd $script_path
model_name="ResNet50"
hub_module_save_dir="./hub_module"
while getopts "m:d:" options
do
case "$options" in
d)
hub_module_save_dir=$OPTARG;;
m)
model_name=$OPTARG;;
?)
echo "unknown options"
exit 1;;
esac
done
sh pretraind_models/download_model.sh ${model_name}
python train.py --create_module=True --pretrained_model=pretraind_models/${model_name} --model ${model_name} --use_gpu=False
......@@ -23,6 +23,7 @@ from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('create_module', bool, False, "create a hub module or not" )
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('total_images', int, 12000, "Training image number.")
......@@ -201,6 +202,19 @@ def train(args):
fluid.io.load_vars(
exe, pretrained_model, main_program=train_prog, predicate=if_exist)
if args.create_module:
assert pretrained_model, "need a pretrained module to create a hub module"
sign1 = hub.create_signature(
"classification", inputs=[image], outputs=[predition])
sign2 = hub.create_signature(
"feature_map", inputs=[image], outputs=[feature_map])
sign3 = hub.create_signature(inputs=[image], outputs=[predition])
hub.create_module(
sign_arr=[sign1, sign2, sign3],
program=train_prog,
module_dir="hub_module_" + args.model)
exit()
visible_device = os.getenv('CUDA_VISIBLE_DEVICES')
if visible_device:
device_num = len(visible_device.split(','))
......@@ -296,16 +310,6 @@ def train(args):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path, main_program=train_prog)
sign1 = hub.create_signature(
"classification", inputs=[image], outputs=[predition])
sign2 = hub.create_signature(
"feature_map", inputs=[image], outputs=[feature_map])
sign3 = hub.create_signature(inputs=[image], outputs=[predition])
hub.create_module(
sign_arr=[sign1, sign2, sign3],
program=train_prog,
module_dir="hub_module" + args.model)
def main():
args = parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册