diff --git a/example.yaml b/example.yaml index 4c4b3d17c3731c1cd69ac1f0a099d70d7908c922..c3358c40ce07231ef25009c3992f558dd09c31a2 100644 --- a/example.yaml +++ b/example.yaml @@ -5,18 +5,18 @@ target_abi: armeabi-v7a # arm64-v8a embed_model_data: 1 models: preview_net: - tf_model_file_path: path/to/model64.pb - tf_input_node: input - tf_output_node: softmax/Reshape_1 + model_file_path: path/to/model64.pb + input_node: input + output_node: softmax/Reshape_1 input_shape: 1,64,64,3 output_shape: 1,64,64,2 runtime: gpu limit_opencl_kernel_time: 0 dsp_mode: 0 capture_net: - tf_model_file_path: path/to/model256.pb - tf_input_node: input_node - tf_output_node: softmax/Reshape_1 + model_file_path: path/to/model256.pb + input_node: input_node + output_node: softmax/Reshape_1 input_shape: 1,256,256,3 output_shape: 1,256,256,2 runtime: gpu diff --git a/generate_model_code.sh b/generate_model_code.sh index 8184b871d347067b5f8a2e58a18ce7301c54481d..6247cdc4ee9f0a20ff4015c2d6e6b6a07026e926 100644 --- a/generate_model_code.sh +++ b/generate_model_code.sh @@ -15,10 +15,10 @@ if [ "${BENCHMARK_FLAG}" = "1" ]; then OBFUSCATE=False fi -bazel-bin/lib/python/tools/tf_converter --input=${TF_MODEL_FILE_PATH} \ +bazel-bin/lib/python/tools/tf_converter --input=${MODEL_FILE_PATH} \ --output=${MODEL_CODEGEN_DIR}/model.cc \ - --input_node=${TF_INPUT_NODE} \ - --output_node=${TF_OUTPUT_NODE} \ + --input_node=${INPUT_NODE} \ + --output_node=${OUTPUT_NODE} \ --data_type=${DATA_TYPE} \ --runtime=${RUNTIME} \ --output_type=source \ diff --git a/mace_tools.py b/mace_tools.py index e4b3a1b09222f674d9e556a3fbbbd28f95e4b31c..d03ac3263961a114f830873bbee20688020ccf94 100644 --- a/mace_tools.py +++ b/mace_tools.py @@ -2,7 +2,7 @@ # Must run at root dir of libmace project. # python tools/mace_tools.py \ -# --config=models/config \ +# --config=tools/example.yaml \ # --round=100 \ # --mode=all @@ -15,8 +15,6 @@ import yaml from ConfigParser import ConfigParser -tf_model_file_dir_key = "TF_MODEL_FILE_DIR" - def run_command(command): print("Run command: {}".format(command)) @@ -204,7 +202,7 @@ def main(unused_args): os.environ[key.upper()] = str(model_config[key]) model_output_dir = FLAGS.output_dir + "/" + target_abi + "/" + os.path.splitext( - model_config["tf_model_file_path"])[0] + model_config["model_file_path"])[0] model_output_dirs.append(model_output_dir) if FLAGS.mode == "build" or FLAGS.mode == "all": diff --git a/validate_tools.sh b/validate_tools.sh index 75f93a801be942b5196ce4a720cba4210ebc5901..17b30a60f245e37d1cee2390b51c09811d474b38 100644 --- a/validate_tools.sh +++ b/validate_tools.sh @@ -23,12 +23,12 @@ if [ "$GENERATE_DATA_OR_NOT" = 1 ]; then else rm -rf ${MODEL_OUTPUT_DIR}/${OUTPUT_FILE_NAME} adb