diff --git a/tools/python/fluidtools/.gitignore b/tools/python/fluidtools/.gitignore index 4a8d0b3da085aa508c0c7f810c0972e728049812..100524f4410784c9aae254b1875e5946d0a9f829 100644 --- a/tools/python/fluidtools/.gitignore +++ b/tools/python/fluidtools/.gitignore @@ -1,4 +1,4 @@ * !run.py !.gitignore -info.txt +!/model-encrypt-tool diff --git a/tools/python/fluidtools/model-encrypt-tool/enc_key_gen b/tools/python/fluidtools/model-encrypt-tool/enc_key_gen new file mode 100755 index 0000000000000000000000000000000000000000..a1cc223847d0f7ca482869e63910743a13e95f5b Binary files /dev/null and b/tools/python/fluidtools/model-encrypt-tool/enc_key_gen differ diff --git a/tools/python/fluidtools/model-encrypt-tool/enc_model_gen b/tools/python/fluidtools/model-encrypt-tool/enc_model_gen new file mode 100755 index 0000000000000000000000000000000000000000..6609434b589fd80ddc15e1503d07bad5e5260f6b Binary files /dev/null and b/tools/python/fluidtools/model-encrypt-tool/enc_model_gen differ diff --git a/tools/python/fluidtools/run.py b/tools/python/fluidtools/run.py index 19507d361168a43195a6059e2dd4de0fdad0eb50..fcbdc8d1e7f4ba7cb71eda83e5da4558db2508b3 100644 --- a/tools/python/fluidtools/run.py +++ b/tools/python/fluidtools/run.py @@ -17,6 +17,8 @@ fast_check = False is_sample_step = False sample_step = 1 sample_num = 20 +need_encrypt = False +checked_encrypt_model_path = "checked_encrypt_model" np.set_printoptions(linewidth=150) @@ -107,6 +109,27 @@ def resave_model(feed_kv): pp_green("has not found wrong shape", 1) pp_green("new model is saved into directory 【{}】".format(checked_model_path), 1) +# 分别加密model和params,加密key使用同一个 +def encrypt_model(): + if not need_encrypt: + return + pp_yellow(dot + dot + " encrypting model") + if not os.path.exists(checked_encrypt_model_path): + os.mkdir(checked_encrypt_model_path) + res = sh("model-encrypt-tool/enc_key_gen -l 20 -c 232") + lines = res.split("\n") + + for line in lines: + if line.startswith("key:"): + line = line.replace('key:','') + sh("model-encrypt-tool/enc_model_gen -k '{}' -c 2 -i checked_model/model -o " + "checked_model/model.ml".format(line)) + sh("model-encrypt-tool/enc_model_gen -k '{}' -c 2 -i checked_model/params -o checked_model/params.ml".format(line)) + pp_green("model has been encrypted, key is : {}".format(line), 1) + sh("mv {} {}".format(checked_model_path + "/*.ml", checked_encrypt_model_path)) + return + pp_red("model encrypt error", 1) + # 生成feed的key-value对 def gen_feed_kv(): feed_kv = {} @@ -413,6 +436,8 @@ def main(): # 重新保存模型 pp_yellow(dot + dot + " checking model correctness") resave_model(feed_kv=feed_kv) + # 输出加密模型 + encrypt_model() # 输出所有中间结果 pp_yellow(dot + dot + " checking output result of every op") save_all_op_output(feed_kv=feed_kv)