From 7b060112568d900c5f2261aecd6fba21bbb45a85 Mon Sep 17 00:00:00 2001 From: liuqi Date: Mon, 5 Mar 2018 16:04:49 +0800 Subject: [PATCH] Refactor caffe transform and validation code. --- example.yaml | 9 +++++++-- generate_model_code.sh | 39 +++++++++++++++++---------------------- mace_tools.py | 5 +++++ validate.py | 2 +- validate_caffe.py | 4 ++-- validate_tools.sh | 10 ++++++---- 6 files changed, 38 insertions(+), 31 deletions(-) diff --git a/example.yaml b/example.yaml index 05a69a17..911aa61a 100644 --- a/example.yaml +++ b/example.yaml @@ -7,6 +7,7 @@ embed_model_data: 1 vlog_level: 0 models: preview_net: + platform: tensorflow model_file_path: path/to/model64.pb # also support http:// and https:// input_node: input_node output_node: output_node @@ -15,12 +16,16 @@ models: runtime: gpu limit_opencl_kernel_time: 0 dsp_mode: 0 + obfuscate: 1 capture_net: - model_file_path: path/to/model256.pb + platform: caffe + model_file_path: path/to/model.prototxt + weight_file_path: path/to/weight.caffemodel input_node: input_node output_node: output_node input_shape: 1,256,256,3 output_shape: 1,256,256,2 - runtime: gpu + runtime: cpu limit_opencl_kernel_time: 1 dsp_mode: 0 + obfuscate: 1 diff --git a/generate_model_code.sh b/generate_model_code.sh index d9c3b9c4..0a780edd 100644 --- a/generate_model_code.sh +++ b/generate_model_code.sh @@ -7,27 +7,22 @@ bazel build //lib/python/tools:converter || exit 1 rm -rf ${MODEL_CODEGEN_DIR} mkdir -p ${MODEL_CODEGEN_DIR} if [ ${DSP_MODE} ]; then - DSP_MODE_FLAG="--dsp_mode=${DSP_MODE}" + DSP_MODE_FLAG="--dsp_mode=${DSP_MODE}" fi -OBFUSCATE=True -if [ "${BENCHMARK_FLAG}" = "1" ]; then - OBFUSCATE=False -fi - -bazel-bin/lib/python/tools/tf_converter --platform=${PLATFORM} \ - --model_file=${MODEL_FILE_PATH} \ - --weight_file=${WEIGHT_FILE_PATH} \ - --model_checksum=${MODEL_SHA256_CHECKSUM} \ - --output=${MODEL_CODEGEN_DIR}/model.cc \ - --input_node=${INPUT_NODE} \ - --output_node=${OUTPUT_NODE} \ - --data_type=${DATA_TYPE} \ - --runtime=${RUNTIME} \ - --output_type=source \ - --template=${LIBMACE_SOURCE_DIR}/lib/python/tools/model.template \ - --model_tag=${MODEL_TAG} \ - --input_shape=${INPUT_SHAPE} \ - ${DSP_MODE_FLAG} \ - --embed_model_data=${EMBED_MODEL_DATA} \ - --obfuscate=${OBFUSCATE} || exit 1 +bazel-bin/lib/python/tools/converter --platform=${PLATFORM} \ + --model_file=${MODEL_FILE_PATH} \ + --weight_file=${WEIGHT_FILE_PATH} \ + --model_checksum=${MODEL_SHA256_CHECKSUM} \ + --output=${MODEL_CODEGEN_DIR}/model.cc \ + --input_node=${INPUT_NODE} \ + --output_node=${OUTPUT_NODE} \ + --data_type=${DATA_TYPE} \ + --runtime=${RUNTIME} \ + --output_type=source \ + --template=${LIBMACE_SOURCE_DIR}/lib/python/tools/model.template \ + --model_tag=${MODEL_TAG} \ + --input_shape=${INPUT_SHAPE} \ + ${DSP_MODE_FLAG} \ + --embed_model_data=${EMBED_MODEL_DATA} \ + --obfuscate=${OBFUSCATE} || exit 1 diff --git a/mace_tools.py b/mace_tools.py index 462fde03..11143a60 100644 --- a/mace_tools.py +++ b/mace_tools.py @@ -225,6 +225,11 @@ def main(unused_args): os.environ["MODEL_FILE_PATH"] = model_output_dir + "/model.pb" urllib.urlretrieve(model_config["model_file_path"], os.environ["MODEL_FILE_PATH"]) + if model_config["platform"] == "caffe" and (model_config["weight_file_path"].startswith( + "http://") or model_config["weight_file_path"].startswith("https://")): + os.environ["WEIGHT_FILE_PATH"] = model_output_dir + "/model.caffemodel" + urllib.urlretrieve(model_config["weight_file_path"], os.environ["WEIGHT_FILE_PATH"]) + if FLAGS.mode == "build" or FLAGS.mode == "run" or FLAGS.mode == "validate" or FLAGS.mode == "all": generate_random_input(model_output_dir) diff --git a/validate.py b/validate.py index 5c66efe3..62401489 100644 --- a/validate.py +++ b/validate.py @@ -56,7 +56,7 @@ def valid_output(out_shape, mace_out_file, tf_out_value): def run_model(input_shape): if not gfile.Exists(FLAGS.model_file): print("Input graph file '" + FLAGS.model_file + "' does not exist!") - return -1 + sys.exit(-1) input_graph_def = tf.GraphDef() with gfile.Open(FLAGS.model_file, "rb") as f: diff --git a/validate_caffe.py b/validate_caffe.py index e33a9e31..cc50242f 100644 --- a/validate_caffe.py +++ b/validate_caffe.py @@ -59,10 +59,10 @@ def valid_output(out_shape, mace_out_file, out_value): def run_model(input_shape): if not os.path.isfile(FLAGS.model_file): print("Input graph file '" + FLAGS.model_file + "' does not exist!") - return -1 + sys.exit(-1) if not os.path.isfile(FLAGS.weight_file): print("Input weight file '" + FLAGS.weight_file + "' does not exist!") - return -1 + sys.exit(-1) caffe.set_mode_cpu() diff --git a/validate_tools.sh b/validate_tools.sh index 7df50c47..90e8f2eb 100644 --- a/validate_tools.sh +++ b/validate_tools.sh @@ -24,7 +24,7 @@ if [ "$GENERATE_DATA_OR_NOT" = 1 ]; then exit 0 fi -if [ "$PLATFORM" = "tensorflow" ];then +if [ "$PLATFORM" == "tensorflow" ];then rm -rf ${MODEL_OUTPUT_DIR}/${OUTPUT_FILE_NAME} adb