diff --git a/example.yaml b/example.yaml index 05a69a172842ab10fd1e0a26c680b45f4925f27e..911aa61a3b5c6be9966e88ab6bf51866e7e57f70 100644 --- a/example.yaml +++ b/example.yaml @@ -7,6 +7,7 @@ embed_model_data: 1 vlog_level: 0 models: preview_net: + platform: tensorflow model_file_path: path/to/model64.pb # also support http:// and https:// input_node: input_node output_node: output_node @@ -15,12 +16,16 @@ models: runtime: gpu limit_opencl_kernel_time: 0 dsp_mode: 0 + obfuscate: 1 capture_net: - model_file_path: path/to/model256.pb + platform: caffe + model_file_path: path/to/model.prototxt + weight_file_path: path/to/weight.caffemodel input_node: input_node output_node: output_node input_shape: 1,256,256,3 output_shape: 1,256,256,2 - runtime: gpu + runtime: cpu limit_opencl_kernel_time: 1 dsp_mode: 0 + obfuscate: 1 diff --git a/generate_model_code.sh b/generate_model_code.sh index d9c3b9c41406758191c18d1d0dd9283c61f2aeb9..0a780eddd6b12f9ebb6afc028a500e7d70c0030d 100644 --- a/generate_model_code.sh +++ b/generate_model_code.sh @@ -7,27 +7,22 @@ bazel build //lib/python/tools:converter || exit 1 rm -rf ${MODEL_CODEGEN_DIR} mkdir -p ${MODEL_CODEGEN_DIR} if [ ${DSP_MODE} ]; then - DSP_MODE_FLAG="--dsp_mode=${DSP_MODE}" + DSP_MODE_FLAG="--dsp_mode=${DSP_MODE}" fi -OBFUSCATE=True -if [ "${BENCHMARK_FLAG}" = "1" ]; then - OBFUSCATE=False -fi - -bazel-bin/lib/python/tools/tf_converter --platform=${PLATFORM} \ - --model_file=${MODEL_FILE_PATH} \ - --weight_file=${WEIGHT_FILE_PATH} \ - --model_checksum=${MODEL_SHA256_CHECKSUM} \ - --output=${MODEL_CODEGEN_DIR}/model.cc \ - --input_node=${INPUT_NODE} \ - --output_node=${OUTPUT_NODE} \ - --data_type=${DATA_TYPE} \ - --runtime=${RUNTIME} \ - --output_type=source \ - --template=${LIBMACE_SOURCE_DIR}/lib/python/tools/model.template \ - --model_tag=${MODEL_TAG} \ - --input_shape=${INPUT_SHAPE} \ - ${DSP_MODE_FLAG} \ - --embed_model_data=${EMBED_MODEL_DATA} \ - --obfuscate=${OBFUSCATE} || exit 1 +bazel-bin/lib/python/tools/converter --platform=${PLATFORM} \ + --model_file=${MODEL_FILE_PATH} \ + --weight_file=${WEIGHT_FILE_PATH} \ + --model_checksum=${MODEL_SHA256_CHECKSUM} \ + --output=${MODEL_CODEGEN_DIR}/model.cc \ + --input_node=${INPUT_NODE} \ + --output_node=${OUTPUT_NODE} \ + --data_type=${DATA_TYPE} \ + --runtime=${RUNTIME} \ + --output_type=source \ + --template=${LIBMACE_SOURCE_DIR}/lib/python/tools/model.template \ + --model_tag=${MODEL_TAG} \ + --input_shape=${INPUT_SHAPE} \ + ${DSP_MODE_FLAG} \ + --embed_model_data=${EMBED_MODEL_DATA} \ + --obfuscate=${OBFUSCATE} || exit 1 diff --git a/mace_tools.py b/mace_tools.py index 462fde03c44b8d375fac452085c1c1e21dde51b8..11143a607c5bb78fa1fa448dc147ee21409704cc 100644 --- a/mace_tools.py +++ b/mace_tools.py @@ -225,6 +225,11 @@ def main(unused_args): os.environ["MODEL_FILE_PATH"] = model_output_dir + "/model.pb" urllib.urlretrieve(model_config["model_file_path"], os.environ["MODEL_FILE_PATH"]) + if model_config["platform"] == "caffe" and (model_config["weight_file_path"].startswith( + "http://") or model_config["weight_file_path"].startswith("https://")): + os.environ["WEIGHT_FILE_PATH"] = model_output_dir + "/model.caffemodel" + urllib.urlretrieve(model_config["weight_file_path"], os.environ["WEIGHT_FILE_PATH"]) + if FLAGS.mode == "build" or FLAGS.mode == "run" or FLAGS.mode == "validate" or FLAGS.mode == "all": generate_random_input(model_output_dir) diff --git a/validate.py b/validate.py index 5c66efe31dcf01ba8aad5cb90eb3b84a4da0adad..62401489a5172af9e44576bfbb855dd44ebb38a0 100644 --- a/validate.py +++ b/validate.py @@ -56,7 +56,7 @@ def valid_output(out_shape, mace_out_file, tf_out_value): def run_model(input_shape): if not gfile.Exists(FLAGS.model_file): print("Input graph file '" + FLAGS.model_file + "' does not exist!") - return -1 + sys.exit(-1) input_graph_def = tf.GraphDef() with gfile.Open(FLAGS.model_file, "rb") as f: diff --git a/validate_caffe.py b/validate_caffe.py index e33a9e31974a5cb5d2bb895f6fb8ecbfdd29a488..cc50242ffe76fdb2f5a20502f635d6bceb481d1a 100644 --- a/validate_caffe.py +++ b/validate_caffe.py @@ -59,10 +59,10 @@ def valid_output(out_shape, mace_out_file, out_value): def run_model(input_shape): if not os.path.isfile(FLAGS.model_file): print("Input graph file '" + FLAGS.model_file + "' does not exist!") - return -1 + sys.exit(-1) if not os.path.isfile(FLAGS.weight_file): print("Input weight file '" + FLAGS.weight_file + "' does not exist!") - return -1 + sys.exit(-1) caffe.set_mode_cpu() diff --git a/validate_tools.sh b/validate_tools.sh index 7df50c477fb2ec2e51fc77cbacd18b606ebddab3..90e8f2eb0f2dd573a41f47fe2b32cf5a1c319085 100644 --- a/validate_tools.sh +++ b/validate_tools.sh @@ -24,7 +24,7 @@ if [ "$GENERATE_DATA_OR_NOT" = 1 ]; then exit 0 fi -if [ "$PLATFORM" = "tensorflow" ];then +if [ "$PLATFORM" == "tensorflow" ];then rm -rf ${MODEL_OUTPUT_DIR}/${OUTPUT_FILE_NAME} adb