提交 7b060112 编写于 作者: L liuqi

Refactor caffe transform and validation code.

上级 759b5842
......@@ -7,6 +7,7 @@ embed_model_data: 1
vlog_level: 0
models:
preview_net:
platform: tensorflow
model_file_path: path/to/model64.pb # also support http:// and https://
input_node: input_node
output_node: output_node
......@@ -15,12 +16,16 @@ models:
runtime: gpu
limit_opencl_kernel_time: 0
dsp_mode: 0
obfuscate: 1
capture_net:
model_file_path: path/to/model256.pb
platform: caffe
model_file_path: path/to/model.prototxt
weight_file_path: path/to/weight.caffemodel
input_node: input_node
output_node: output_node
input_shape: 1,256,256,3
output_shape: 1,256,256,2
runtime: gpu
runtime: cpu
limit_opencl_kernel_time: 1
dsp_mode: 0
obfuscate: 1
......@@ -7,27 +7,22 @@ bazel build //lib/python/tools:converter || exit 1
rm -rf ${MODEL_CODEGEN_DIR}
mkdir -p ${MODEL_CODEGEN_DIR}
if [ ${DSP_MODE} ]; then
DSP_MODE_FLAG="--dsp_mode=${DSP_MODE}"
DSP_MODE_FLAG="--dsp_mode=${DSP_MODE}"
fi
OBFUSCATE=True
if [ "${BENCHMARK_FLAG}" = "1" ]; then
OBFUSCATE=False
fi
bazel-bin/lib/python/tools/tf_converter --platform=${PLATFORM} \
--model_file=${MODEL_FILE_PATH} \
--weight_file=${WEIGHT_FILE_PATH} \
--model_checksum=${MODEL_SHA256_CHECKSUM} \
--output=${MODEL_CODEGEN_DIR}/model.cc \
--input_node=${INPUT_NODE} \
--output_node=${OUTPUT_NODE} \
--data_type=${DATA_TYPE} \
--runtime=${RUNTIME} \
--output_type=source \
--template=${LIBMACE_SOURCE_DIR}/lib/python/tools/model.template \
--model_tag=${MODEL_TAG} \
--input_shape=${INPUT_SHAPE} \
${DSP_MODE_FLAG} \
--embed_model_data=${EMBED_MODEL_DATA} \
--obfuscate=${OBFUSCATE} || exit 1
bazel-bin/lib/python/tools/converter --platform=${PLATFORM} \
--model_file=${MODEL_FILE_PATH} \
--weight_file=${WEIGHT_FILE_PATH} \
--model_checksum=${MODEL_SHA256_CHECKSUM} \
--output=${MODEL_CODEGEN_DIR}/model.cc \
--input_node=${INPUT_NODE} \
--output_node=${OUTPUT_NODE} \
--data_type=${DATA_TYPE} \
--runtime=${RUNTIME} \
--output_type=source \
--template=${LIBMACE_SOURCE_DIR}/lib/python/tools/model.template \
--model_tag=${MODEL_TAG} \
--input_shape=${INPUT_SHAPE} \
${DSP_MODE_FLAG} \
--embed_model_data=${EMBED_MODEL_DATA} \
--obfuscate=${OBFUSCATE} || exit 1
......@@ -225,6 +225,11 @@ def main(unused_args):
os.environ["MODEL_FILE_PATH"] = model_output_dir + "/model.pb"
urllib.urlretrieve(model_config["model_file_path"], os.environ["MODEL_FILE_PATH"])
if model_config["platform"] == "caffe" and (model_config["weight_file_path"].startswith(
"http://") or model_config["weight_file_path"].startswith("https://")):
os.environ["WEIGHT_FILE_PATH"] = model_output_dir + "/model.caffemodel"
urllib.urlretrieve(model_config["weight_file_path"], os.environ["WEIGHT_FILE_PATH"])
if FLAGS.mode == "build" or FLAGS.mode == "run" or FLAGS.mode == "validate" or FLAGS.mode == "all":
generate_random_input(model_output_dir)
......
......@@ -56,7 +56,7 @@ def valid_output(out_shape, mace_out_file, tf_out_value):
def run_model(input_shape):
if not gfile.Exists(FLAGS.model_file):
print("Input graph file '" + FLAGS.model_file + "' does not exist!")
return -1
sys.exit(-1)
input_graph_def = tf.GraphDef()
with gfile.Open(FLAGS.model_file, "rb") as f:
......
......@@ -59,10 +59,10 @@ def valid_output(out_shape, mace_out_file, out_value):
def run_model(input_shape):
if not os.path.isfile(FLAGS.model_file):
print("Input graph file '" + FLAGS.model_file + "' does not exist!")
return -1
sys.exit(-1)
if not os.path.isfile(FLAGS.weight_file):
print("Input weight file '" + FLAGS.weight_file + "' does not exist!")
return -1
sys.exit(-1)
caffe.set_mode_cpu()
......
......@@ -24,7 +24,7 @@ if [ "$GENERATE_DATA_OR_NOT" = 1 ]; then
exit 0
fi
if [ "$PLATFORM" = "tensorflow" ];then
if [ "$PLATFORM" == "tensorflow" ];then
rm -rf ${MODEL_OUTPUT_DIR}/${OUTPUT_FILE_NAME}
adb </dev/null pull ${PHONE_DATA_DIR}/${OUTPUT_FILE_NAME} ${MODEL_OUTPUT_DIR}
......@@ -37,7 +37,7 @@ if [ "$PLATFORM" = "tensorflow" ];then
--input_shape ${INPUT_SHAPE} \
--output_shape ${OUTPUT_SHAPE} || exit 1
elif [ "$PLATFORM" = "caffe" ];then
elif [ "$PLATFORM" == "caffe" ];then
IMAGE_NAME=mace-caffe:latest
CONTAINER_NAME=mace_caffe_validator
RES_FILE=validation.result
......@@ -60,13 +60,15 @@ elif [ "$PLATFORM" = "caffe" ];then
rm -rf ${MODEL_OUTPUT_DIR}/${OUTPUT_FILE_NAME}
adb </dev/null pull ${PHONE_DATA_DIR}/${OUTPUT_FILE_NAME} ${MODEL_OUTPUT_DIR}
MODEL_FILE_NAME=$(basename ${MODEL_FILE_PATH})
WEIGHT_FILE_NAME=$(basename ${WEIGHT_FILE_PATH})
docker cp tools/validate_caffe.py ${CONTAINER_NAME}:/mace
docker cp ${MODEL_OUTPUT_DIR}/${INPUT_FILE_NAME} ${CONTAINER_NAME}:/mace
docker cp ${MODEL_OUTPUT_DIR}/${OUTPUT_FILE_NAME} ${CONTAINER_NAME}:/mace
docker cp ${MODEL_FILE_PATH} ${CONTAINER_NAME}:/mace
docker cp ${WEIGHT_FILE_PATH} ${CONTAINER_NAME}:/mace
docker exec -it ${CONTAINER_NAME} python /mace/validate_caffe.py --model_file /mace/${MODEL_NAME} \
--weight_file /mace/${WEIGHT_NAME} \
docker exec -it ${CONTAINER_NAME} python /mace/validate_caffe.py --model_file /mace/${MODEL_FILE_NAME} \
--weight_file /mace/${WEIGHT_FILE_NAME} \
--input_file /mace/${INPUT_FILE_NAME} \
--mace_out_file /mace/${OUTPUT_FILE_NAME} \
--mace_runtime ${RUNTIME} \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册