“761b3297934c741aaa4d13130a3e0a31131506d5”上不存在“paddle/fluid/operators/math/sequence_scale.cc”
提交 82b2f814 编写于 作者: Z zhongjiafeng

paddle model encrypt & decrypt support with linux and windows

上级 d9d547dc
cmake_minimum_required(VERSION 3.12)
project(paddle_model_protect)
set(CMAKE_CXX_STANDARD 11)
IF (CMAKE_SYSTEM_NAME MATCHES "Windows")
option(PM_EXPORTS "export symbols in windows" ON)
IF (PM_EXPORTS)
message("add_definitions of PM_EXPORTS")
add_definitions("-DPM_EXPORTS")
ENDIF ()
ENDIF ()
IF (CMAKE_SYSTEM_NAME MATCHES "Linux")
# use "-fvisibility=hidden" instead of "-Wl,--version-script ${CMAKE_CURRENT_SOURCE_DIR}/export_rule.map"
set(CMAKE_C_VISIBILITY_PRESET hidden)
set(CMAKE_CXX_VISIBILITY_PRESET hidden)
set(CMAKE_C_FLAGS "-g -O2 -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-g -O2 -fPIC ${CMAKE_CXX_FLAGS}")
ELSEIF (CMAKE_SYSTEM_NAME MATCHES "Windows")
set(CMAKE_C_FLAGS_RELEASE "/MT")
set(CMAKE_CXX_FLAGS_RELEASE "/MT")
ENDIF ()
SET(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/output/bin)
SET(LIBRARY_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/output/lib)
file(COPY "${PROJECT_SOURCE_DIR}/include/paddle_model_encrypt.h" DESTINATION "${PROJECT_SOURCE_DIR}/output/include")
file(COPY "${PROJECT_SOURCE_DIR}/include/paddle_model_decrypt.h" DESTINATION "${PROJECT_SOURCE_DIR}/output/include")
file(COPY "${PROJECT_SOURCE_DIR}/include/model_code.h" DESTINATION "${PROJECT_SOURCE_DIR}/output/include")
set(SRC_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src")
set(OPENSSL_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/3rd/openssl-1.1.0k/install-${CMAKE_SYSTEM_PROCESSOR}")
set(OPENSSL_INCLUDE "${OPENSSL_ROOT_DIR}/include")
IF (CMAKE_SYSTEM_NAME MATCHES "Windows")
set(OPENSSL_LIBS
"${OPENSSL_ROOT_DIR}/lib/libssl_static.lib"
"${OPENSSL_ROOT_DIR}/lib/libcrypto_static.lib")
ELSEIF (CMAKE_SYSTEM_NAME MATCHES "Linux")
set(OPENSSL_LIBS
"${OPENSSL_ROOT_DIR}/lib/libssl.a"
"${OPENSSL_ROOT_DIR}/lib/libcrypto.a")
ENDIF ()
set(PADDLE_INCLUDE_DIR "${PADDLE_DIR}/include")
IF (CMAKE_SYSTEM_NAME MATCHES "Windows")
# -DPADDLE_DIR=C:\developer\Paddle-developer\Paddle\build\fluid_inference_install_dir\paddle
set(PADDLE_LIBS "${PADDLE_DIR}/lib/paddle_fluid.lib")
ENDIF ()
include_directories(
include
${OPENSSL_INCLUDE}
${PADDLE_INCLUDE_DIR}
)
IF (MSVC)
# Visual Studio 2015
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ws2_32.lib /NODEFAULTLIB:libcmt.lib")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ws2_32.lib /NODEFAULTLIB:libcmt.lib")
ENDIF ()
add_subdirectory(${SRC_ROOT_DIR})
一、Linux
在 centos 7 上GCC 4.8.5 编译通过
Step1: 编译
编译 cmake 的命令在 build.sh 中,请根据实际情况修改主要参数PADDLE_DIR的路径
修改脚本设置好参数后,执行build脚本
sh build.sh
Step2: 产出在output目录
2.1 头文件
include/model_code.h
include/paddle_model_encrypt.h
include/paddle_model_decrypt.h
2.2 编译产出库
lib/libpmodel-encrypt.so
lib/libpmodel-decrypt.so
2.3 执行工具(使用-h参数查看)
bin/paddle_encrypt_tool
二、Windows
在windows 10 Visual Studio 14 2015 上编译通过
Step1: 编译
修改 build.bat 中 PADDLE_DIR 的路径
执行 build.bat 脚本
Step2:打开 blend Visual Studio 2015,
选择 open project -> 找到 Step1 中生成的 paddle—model-protect.sln -> 选择 Release 和 x64 -> ALL BUILD -> 右键生成
@echo off
set PADDLE_DIR=/path/to/Paddle/include
set workPath=%~dp0
set thirdPartyPath=%~dp03rd
if exist %thirdPartyPath% (
echo %thirdPartyPath% exist
rd /S /Q %thirdPartyPath%
)
echo createDir %thirdPartyPath%
md %thirdPartyPath%
cd %thirdPartyPath%
wget --no-check-certificate https://bj.bcebos.com/paddlex/tools/openssl-1.1.0k.tar.gz
tar -zxvf openssl-1.1.0k.tar.gz
del openssl-1.1.0k.tar.gz
cd %workPath%
if exist %workPath%build (
rd /S /Q %workPath%build
)
if exist %workPath%\output (
rd /S /Q %workPath%\output
)
MD %workPath%build
MD %workPath%\output
cd %workPath%build
cmake .. -G "Visual Studio 14 2015" -A x64 -T host=x64 -DCMAKE_BUILD_TYPE=Release -DPADDLE_DIR=%PADDLE_DIR%
cd %workPath%
PADDLE_DIR=/home/parallels/developers/paddleX-PR/paddle
if [ ! -d "3rd" ]; then
mkdir 3rd
fi
cd 3rd
wget https://bj.bcebos.com/paddlex/tools/openssl-1.1.0k.tar.gz
tar -zxvf openssl-1.1.0k.tar.gz
rm openssl-1.1.0k.tar.gz
cd ..
rm -rf build output
mkdir build && cd build
cmake .. \
-DPADDLE_DIR=${PADDLE_DIR}
make
@echo off
set workPath=%~dp0
set thirdPartyPath=%~dp03rd
if exist %thirdPartyPath% (
echo %thirdPartyPath% exist
rd /S /Q %thirdPartyPath%
)
cd %workPath%
if exist %workPath%build (
rd /S /Q %workPath%build
)
if exist %workPath%\output (
rd /S /Q %workPath%\output
)
if [ -d "3rd" ]; then
rm -rf 3rd
fi
if [ -d "build" ]; then
rm -rf build
fi
if [ -d "output" ]; then
rm -rf output
fi
{
global:
paddle_generate_random_key;
paddle_encrypt_model;
paddle_security_load_model;
paddle_check_file_encrypted;
paddle_encrypt_dir;
local:
*;
};
\ No newline at end of file
#ifndef PADDLE_MODEL_PROTECT_MODEL_CODE_H
#define PADDLE_MODEL_PROTECT_MODEL_CODE_H
#ifdef __cplusplus
extern "C" {
#endif
enum {
CODE_OK = 0,
CODE_OPEN_FAILED = 100,
CODE_READ_FILE_PTR_IS_NULL = 101,
CODE_AES_GCM_ENCRYPT_FIALED = 102,
CODE_AES_GCM_DECRYPT_FIALED = 103,
CODE_KEY_NOT_MATCH = 104,
CODE_KEY_LENGTH_ABNORMAL = 105,
CODE_NOT_EXIST_DIR = 106,
CODE_FILES_EMPTY_WITH_DIR = 107,
CODE_MODEL_FILE_NOT_EXIST = 108,
CODE_PARAMS_FILE_NOT_EXIST = 109,
CODE_MODEL_YML_FILE_NOT_EXIST = 110,
CODE_MKDIR_FAILED = 111
};
#ifdef __cplusplus
}
#endif
#endif //PADDLE_MODEL_PROTECT_MODEL_CODE_H
#pragma once
#include <stdio.h>
#include "paddle_inference_api.h"
#ifndef PADDLE_MODEL_PROTECT_API_PADDLE_MODEL_DECRYPT_H
#define PADDLE_MODEL_PROTECT_API_PADDLE_MODEL_DECRYPT_H
#ifdef WIN32
#ifdef PM_EXPORTS
#define PDD_MODEL_API __declspec(dllexport)
#else
#define PDD_MODEL_API __declspec(dllimport)
#endif
#endif
#ifdef linux
#define PDD_MODEL_API __attribute__((visibility("default")))
#endif
#ifdef __cplusplus
extern "C" {
#endif
/**
* load (un)encrypted model and params to paddle::AnalysisConfig
* @param config
* @param key 加解密key(注:该SDK能符合的key信息为32字节转为BASE64编码后才能通过)
* @param model_file 模型文件路径
* @param param_file 参数文件路径
* @return error_code
*/
PDD_MODEL_API int paddle_security_load_model(paddle::AnalysisConfig* config,
const char* key,
const char* model_file,
const char* param_file);
/**
* check file (un)encrypted?
* @param file_path
* @return
*/
PDD_MODEL_API int paddle_check_file_encrypted(const char* file_path);
PDD_MODEL_API std::string decrypt_file(const char* file_path, const char* key);
#ifdef __cplusplus
}
#endif
#endif //PADDLE_MODEL_PROTECT_API_PADDLE_MODEL_DECRYPT_H
#pragma once
#include <iostream>
#ifndef PADDLE_MODEL_PROTECT_API_PADDLE_MODEL_ENCRYPT_H
#define PADDLE_MODEL_PROTECT_API_PADDLE_MODEL_ENCRYPT_H
#ifdef WIN32
#ifdef PM_EXPORTS
#define PDE_MODEL_API __declspec(dllexport)
#else
#define PDE_MODEL_API __declspec(dllimport)
#endif
#endif
#ifdef linux
#define PDE_MODEL_API __attribute__((visibility("default")))
#endif
#ifdef __cplusplus
extern "C" {
#endif
/**
* generate random key
* 产生随机的 key 信息,如果想要使用当前 SDK,
* 对于传入的key信息有要求(需符合产生32字节随机值后做 BASE64 编码
* @return
*/
PDE_MODEL_API std::string paddle_generate_random_key();
/**
* encrypt __model__, __params__ files in src_dir to dst_dir
* @param keydata
* @param src_dir
* @param dst_dir
* @return
*/
PDE_MODEL_API int paddle_encrypt_dir(const char* keydata, const char* src_dir, const char* dst_dir);
/**
* encrypt file
* @param keydata 可使用由 paddle_generate_random_key 接口产生的key,也可以根据规则自己生成
* @param infile
* @param outfile
* @return error_code
*/
PDE_MODEL_API int paddle_encrypt_model(const char* keydata, const char* infile, const char* outfile);
#ifdef __cplusplus
}
#endif
#endif //PADDLE_MODEL_PROTECT_API_PADDLE_MODEL_ENCRYPT_H
#include <iostream>
#include <cstring>
#include "model_code.h"
#include "paddle_model_encrypt.h"
#include "paddle_inference_api.h"
#ifdef linux
#define RESET "\033[0m"
#define BOLD "\033[1m"
#define BOLDGREEN "\033[1m\033[32m"
#elif WIN32
#define RESET ""
#define BOLD ""
#define BOLDGREEN ""
#endif
void help() {
std::cout << BOLD << "*** paddle_encrypt_tool Usage ***" << RESET << std::endl;
std::cout << "[1]Help:" << std::endl;
std::cout << "\t-h" << std::endl;
std::cout << "[2]Generate random key and encrypt dir files" << std::endl;
std::cout << "\t-model_dir\tmodel_dir_ori\t-save_dir\tencrypted_models" << std::endl;
std::cout << "[3]Generate random key for encrypt file" << std::endl;
std::cout << "\t-g" << std::endl;
std::cout << "[4]Encrypt file:" << std::endl;
std::cout << "\t-e\t-key\tkeydata\t-infile\tinfile\t-outfile\toutfile" << std::endl;
}
int main(int argc, char** argv) {
switch (argc) {
case 2:
if (strcmp(argv[1], "-g") == 0) {
std::cout << BOLD << "Generate key success: \n\t" << RESET << BOLDGREEN << paddle_generate_random_key()
<< RESET << std::endl;
} else {
help();
}
break;
case 5:
if (strcmp(argv[1], "-model_dir") == 0 && strcmp(argv[3], "-save_dir") == 0) {
std::string key_random = paddle_generate_random_key();
std::cout << BOLD << "Output: " << "Encryption key: \n\t" << RESET << BOLDGREEN
<< key_random << RESET << std::endl;
int ret = paddle_encrypt_dir(key_random.c_str(), argv[2], argv[4]);
switch (ret) {
case CODE_OK:
std::cout << "Success, Encrypt __model__, __params__ to " << argv[4] << "(dir) success!"
<< std::endl;
break;
case CODE_MODEL_FILE_NOT_EXIST:
std::cout << "Failed, errorcode = " << ret << ", could't find __model__(file) in " << argv[2]
<< std::endl;
break;
case CODE_MODEL_YML_FILE_NOT_EXIST:
std::cout << "Failed, errorcode = " << ret << ", could't find model.yml(file) in " << argv[2]
<< std::endl;
break;
case CODE_PARAMS_FILE_NOT_EXIST:
std::cout << "Failed, errorcode = " << ret << ", could't find __params__(file) in " << argv[2]
<< std::endl;
break;
case CODE_NOT_EXIST_DIR:
std::cout << "Failed, errorcode = " << ret << ", " << argv[2] << "(dir) not exist" << std::endl;
break;
case CODE_FILES_EMPTY_WITH_DIR:
std::cout << "Failed, errorcode = " << ret << ", could't find any files in " << argv[2]
<< std::endl;
break;
default:std::cout << "Failed, errorcode = " << ret << ", others" << std::endl;
break;
}
} else {
help();
}
break;
case 8:
if (strcmp(argv[1], "-e") == 0 && strcmp(argv[2], "-key") == 0 && strcmp(argv[4], "-infile") == 0
&& strcmp(argv[6], "-outfile") == 0) {
int ret_encrypt = paddle_encrypt_model(argv[3], argv[5], argv[7]);
if (ret_encrypt == 0) {
std::cout << "Encrypt " << argv[5] << "(file) to " << argv[7] << "(file) success" << std::endl;
} else {
std::cout << "Encrypt " << argv[5] << " failed, ret = " << ret_encrypt << std::endl;
}
} else {
help();
}
break;
default:help();
}
#ifdef WIN32
system("pause");
#endif
return 0;
}
\ No newline at end of file
set(SRC_COMMON
util/crypto/basic.cpp
util/system_utils.cpp
util/io_utils.cpp
util/crypto/aes_gcm.cpp
util/crypto/sha256_utils.cpp
util/crypto/base64.cpp)
set(SRC_ENCRYPT
safeapi/paddle_model_encrypt.cpp
${SRC_COMMON}
)
set(SRC_DECRYPT
safeapi/paddle_model_decrypt.cpp
${SRC_COMMON})
# encrypt: libpmodel-encrypt.so
add_library(pmodel-encrypt SHARED
${SRC_ENCRYPT})
IF (CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(pmodel-encrypt
${OPENSSL_LIBS}
)
ELSEIF (CMAKE_SYSTEM_NAME MATCHES "Linux")
target_link_libraries(pmodel-encrypt
${OPENSSL_LIBS}
-ldl -lpthread
)
ENDIF ()
# decrypt: libpmodel-decrypt.so
add_library(pmodel-decrypt SHARED
${SRC_DECRYPT})
IF (CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(pmodel-decrypt
${OPENSSL_LIBS}
${PADDLE_LIBS}
)
ELSEIF (CMAKE_SYSTEM_NAME MATCHES "Linux")
target_link_libraries(pmodel-decrypt
${OPENSSL_LIBS}
-ldl -lpthread
)
ENDIF ()
# tool: paddle_encrypt_tool
add_executable(paddle_encrypt_tool
../sample/paddle_encrypt_tool.cpp
safeapi/paddle_model_encrypt.cpp
util/crypto/basic.cpp
util/system_utils.cpp
util/io_utils.cpp
util/crypto/aes_gcm.cpp
util/crypto/sha256_utils.cpp
util/crypto/base64.cpp)
IF (CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(paddle_encrypt_tool
${OPENSSL_LIBS}
)
ELSEIF (CMAKE_SYSTEM_NAME MATCHES "Linux")
target_link_libraries(paddle_encrypt_tool
${OPENSSL_LIBS}
-ldl -lpthread
)
ENDIF ()
#ifndef PADDLE_MODEL_PROTECT_CONSTANT_CONSTANT_MODEL_H
#define PADDLE_MODEL_PROTECT_CONSTANT_CONSTANT_MODEL_H
namespace constant {
const static std::string MAGIC_NUMBER = "PADDLE";
const static std::string VERSION = "1";
const static int MAGIC_NUMBER_LEN = 6;
const static int VERSION_LEN = 1;
const static int TAG_LEN = 128;
}
#endif //PADDLE_MODEL_PROTECT_CONSTANT_CONSTANT_MODEL_H
#include <iostream>
#include <string>
#include <string.h>
#include "paddle_model_decrypt.h"
#include "model_code.h"
#include "../util/crypto/aes_gcm.h"
#include "../util/io_utils.h"
#include "../util/log.h"
#include "../constant/constant_model.h"
#include "../util/system_utils.h"
#include "../util/crypto/base64.h"
/**
* 0 - encrypted
* 1 - unencrypt
*/
int paddle_check_file_encrypted(const char* file_path) {
return util::SystemUtils::check_file_encrypted(file_path);
}
std::string decrypt_file(const char* file_path, const char* key) {
int ret = paddle_check_file_encrypted(file_path);
if (ret != CODE_OK) {
LOGD("[M]check file encrypted failed, code: %d", ret);
return std::string();
}
// std::string key_str = util::crypto::Base64Utils::decode(std::string(key));
std::string key_str = baidu::base::base64::base64_decode(std::string(key));
int ret_check = util::SystemUtils::check_key_match(key_str.c_str(), file_path);
if (ret_check != CODE_OK) {
LOGD("[M]check key failed in decrypt_file, code: %d", ret_check);
return std::string();
}
unsigned char* aes_key = (unsigned char*) malloc(sizeof(unsigned char) * AES_GCM_KEY_LENGTH);
unsigned char* aes_iv = (unsigned char*) malloc(sizeof(unsigned char) * AES_GCM_IV_LENGTH);
memcpy(aes_key, key_str.c_str(), AES_GCM_KEY_LENGTH);
memcpy(aes_iv, key_str.c_str() + 16, AES_GCM_IV_LENGTH);
size_t pos = constant::MAGIC_NUMBER_LEN + constant::VERSION_LEN + constant::TAG_LEN;
// read encrypted data
unsigned char* dataptr = NULL;
size_t data_len = 0;
int ret_read_data = ioutil::read_with_pos(file_path, pos, &dataptr, &data_len);
if (ret_read_data != CODE_OK) {
LOGD("[M]read file failed, code = %d", ret_read_data);
return std::string();
}
// decrypt model data
size_t model_plain_len = data_len - AES_GCM_TAG_LENGTH;
unsigned char* model_plain = (unsigned char*) malloc(sizeof(unsigned char) * model_plain_len);
int ret_decrypt_file =
util::crypto::AesGcm::decrypt_aes_gcm(
dataptr,
data_len,
aes_key,
aes_iv,
model_plain,
reinterpret_cast<int&>(model_plain_len));
free(dataptr);
free(aes_key);
free(aes_iv);
if (ret_decrypt_file != CODE_OK) {
free(model_plain);
LOGD("[M]decrypt file failed, decrypt ret = %d", ret_decrypt_file);
return std::string();
}
std::string result((const char*)model_plain);
free(model_plain);
return result;
}
/**
* support model_file encrypted or unencrypt
* support params_file encrypted or unencrypt
* all in one interface
*/
int paddle_security_load_model(
paddle::AnalysisConfig* config,
const char* key,
const char* model_file,
const char* param_file) {
// 0 - file encrypted 1 - file unencrypted
int m_en_flag = util::SystemUtils::check_file_encrypted(model_file);
if (m_en_flag == CODE_OPEN_FAILED) {
return m_en_flag;
}
int p_en_flag = util::SystemUtils::check_file_encrypted(param_file);
if (p_en_flag == CODE_OPEN_FAILED) {
return p_en_flag;
}
unsigned char* aes_key = NULL;
unsigned char* aes_iv = NULL;
if (m_en_flag == 0 || p_en_flag == 0) {
// std::string key_str = util::crypto::Base64Utils::decode(std::string(key));
std::string key_str = baidu::base::base64::base64_decode(std::string(key));
int ret_check = 0;
if (m_en_flag == 0) {
ret_check = util::SystemUtils::check_key_match(key_str.c_str(), model_file);
if (ret_check != CODE_OK) {
LOGD("[M]check key failed in model_file");
return ret_check;
}
}
if (p_en_flag == 0) {
ret_check = util::SystemUtils::check_key_match(key_str.c_str(), param_file);
if (ret_check != CODE_OK) {
LOGD("[M]check key failed in param_file");
return ret_check;
}
}
aes_key = (unsigned char*) malloc(sizeof(unsigned char) * AES_GCM_KEY_LENGTH);
aes_iv = (unsigned char*) malloc(sizeof(unsigned char) * AES_GCM_IV_LENGTH);
memcpy(aes_key, key_str.c_str(), AES_GCM_KEY_LENGTH);
memcpy(aes_iv, key_str.c_str() + 16, AES_GCM_IV_LENGTH);
}
size_t pos = constant::MAGIC_NUMBER_LEN + constant::VERSION_LEN + constant::TAG_LEN;
// read encrypted model
unsigned char* model_dataptr = NULL;
size_t model_data_len = 0;
int ret_read_model = ioutil::read_with_pos(model_file, pos, &model_dataptr, &model_data_len);
if (ret_read_model != CODE_OK) {
LOGD("[M]read model failed");
return ret_read_model;
}
size_t model_plain_len = 0;
unsigned char* model_plain = NULL;
if (m_en_flag == 0) {
// decrypt model data
model_plain_len = model_data_len - AES_GCM_TAG_LENGTH;
model_plain = (unsigned char*) malloc(sizeof(unsigned char) * model_plain_len);
int ret_decrypt_model =
util::crypto::AesGcm::decrypt_aes_gcm(model_dataptr,
model_data_len,
aes_key,
aes_iv,
model_plain,
reinterpret_cast<int&>(model_plain_len));
free(model_dataptr);
if (ret_decrypt_model != CODE_OK) {
free(aes_key);
free(aes_iv);
free(model_plain);
LOGD("[M]decrypt model failed, decrypt ret = %d", ret_decrypt_model);
return CODE_AES_GCM_DECRYPT_FIALED;
}
} else {
model_plain = model_dataptr;
model_plain_len = model_data_len;
}
// read encrypted params
unsigned char* params_dataptr = NULL;
size_t params_data_len = 0;
int ret_read_params = ioutil::read_with_pos(param_file, pos, &params_dataptr, &params_data_len);
if (ret_read_params != CODE_OK) {
LOGD("[M]read params failed");
return ret_read_params;
}
size_t params_plain_len = 0;
unsigned char* params_plain = NULL;
if (p_en_flag == 0) {
// decrypt params data
params_plain_len = params_data_len - AES_GCM_TAG_LENGTH;
params_plain = (unsigned char*) malloc(sizeof(unsigned char) * params_plain_len);
int ret_decrypt_params =
util::crypto::AesGcm::decrypt_aes_gcm(params_dataptr,
params_data_len,
aes_key,
aes_iv,
params_plain,
reinterpret_cast<int&>(params_plain_len));
free(params_dataptr);
free(aes_key);
free(aes_iv);
if (ret_decrypt_params != CODE_OK) {
free(params_plain);
LOGD("[M]decrypt params failed, decrypt ret = %d", ret_decrypt_params);
return CODE_AES_GCM_DECRYPT_FIALED;
}
} else {
params_plain = params_dataptr;
params_plain_len = params_data_len;
}
LOGD("Prepare to set config");
config->SetModelBuffer(reinterpret_cast<const char*>(model_plain), model_plain_len,
reinterpret_cast<const char*>(params_plain), params_plain_len);
if (m_en_flag == 1) {
free(model_dataptr);
}
if (p_en_flag == 1) {
free(params_dataptr);
}
return CODE_OK;
}
#include <iostream>
#include <string>
#include <memory>
#include <vector>
#include <string.h>
#include "paddle_model_encrypt.h"
#include "model_code.h"
#include "../util/system_utils.h"
#include "../util/io_utils.h"
#include "../constant/constant_model.h"
#include "../util/crypto/aes_gcm.h"
#include "../util/crypto/sha256_utils.h"
#include "../util/crypto/base64.h"
#include "../util/log.h"
std::string paddle_generate_random_key() {
std::string tmp = util::SystemUtils::random_key_iv(AES_GCM_KEY_LENGTH);
// return util::crypto::Base64Utils::encode(tmp);
return baidu::base::base64::base64_encode(tmp);
}
int paddle_encrypt_dir(const char* keydata, const char* src_dir, const char* dst_dir) {
std::vector<std::string> files;
int ret_files = ioutil::read_dir_files(src_dir, files);
if (ret_files == -1) {
return CODE_NOT_EXIST_DIR;
}
if (ret_files == 0) {
return CODE_FILES_EMPTY_WITH_DIR;
}
// check model.yml, __model__, __params__ exist or not
if (util::SystemUtils::check_pattern_exist(files, "model.yml")) {
return CODE_MODEL_YML_FILE_NOT_EXIST;
}
if (util::SystemUtils::check_pattern_exist(files, "__model__")) {
return CODE_MODEL_FILE_NOT_EXIST;
}
if (util::SystemUtils::check_pattern_exist(files, "__params__")) {
return CODE_PARAMS_FILE_NOT_EXIST;
}
std::string src_str(src_dir);
if (src_str[src_str.length() - 1] != '/') {
src_str.append("/");
}
std::string dst_str(dst_dir);
if (dst_str[dst_str.length() - 1] != '/') {
dst_str.append("/");
}
int ret = CODE_OK;
ret = ioutil::dir_exist_or_mkdir(dst_str.c_str());
for (int i = 0; i < files.size(); ++i) {
if (strcmp(files[i].c_str(), "__model__") == 0 || strcmp(files[i].c_str(), "__params__") == 0 || strcmp(files[i].c_str(), "model.yml") == 0) {
std::string infile = src_str + files[i];
std::string outfile = dst_str + files[i] + ".encrypted";
ret = paddle_encrypt_model(keydata, infile.c_str(), outfile.c_str());
} else {
std::string infile = src_str + files[i];
std::string outfile = dst_str + files[i];
ret = ioutil::read_file_to_file(infile.c_str(), outfile.c_str());
}
if (ret != CODE_OK) {
return ret;
}
}
files.clear();
return ret;
}
int paddle_encrypt_model(const char* keydata, const char* infile, const char* outfile) {
// std::string key_str = util::crypto::Base64Utils::decode(std::string(keydata));
std::string key_str = baidu::base::base64::base64_decode(std::string(keydata));
if (key_str.length() != 32) {
return CODE_KEY_LENGTH_ABNORMAL;
}
unsigned char* plain = NULL;
size_t plain_len = 0;
int ret_read = ioutil::read_file(infile, &plain, &plain_len);
if (ret_read != CODE_OK) {
return ret_read;
}
unsigned char* aes_key = (unsigned char*) malloc(sizeof(unsigned char) * AES_GCM_KEY_LENGTH);
unsigned char* aes_iv = (unsigned char*) malloc(sizeof(unsigned char) * AES_GCM_IV_LENGTH);
memcpy(aes_key, key_str.c_str(), AES_GCM_KEY_LENGTH);
memcpy(aes_iv, key_str.c_str() + 16, AES_GCM_IV_LENGTH);
unsigned char* cipher = (unsigned char*) malloc(sizeof(unsigned char) * (plain_len + AES_GCM_TAG_LENGTH));
size_t cipher_len = 0;
int ret_encrypt =
util::crypto::AesGcm::encrypt_aes_gcm(plain,
plain_len,
aes_key,
aes_iv,
cipher,
reinterpret_cast<int&>(cipher_len));
free(aes_key);
free(aes_iv);
if (ret_encrypt != CODE_OK) {
LOGD("[M]aes encrypt ret code: %d", ret_encrypt);
free(plain);
free(cipher);
return CODE_AES_GCM_ENCRYPT_FIALED;
}
std::string randstr = util::SystemUtils::random_str(constant::TAG_LEN);
std::string aes_key_iv(key_str);
std::string sha256_key_iv = util::crypto::SHA256Utils::sha256_string(aes_key_iv);
for (int i = 0; i < 64; ++i) {
randstr[i] = sha256_key_iv[i];
}
size_t header_len = constant::MAGIC_NUMBER_LEN + constant::VERSION_LEN + constant::TAG_LEN;
unsigned char* header = (unsigned char*) malloc(sizeof(unsigned char) * header_len);
memcpy(header, constant::MAGIC_NUMBER.c_str(), constant::MAGIC_NUMBER_LEN);
memcpy(header + constant::MAGIC_NUMBER_LEN, constant::VERSION.c_str(), constant::VERSION_LEN);
memcpy(header + constant::MAGIC_NUMBER_LEN + constant::VERSION_LEN, randstr.c_str(), constant::TAG_LEN);
int ret_write_file = ioutil::write_file(outfile, header, header_len);
ret_write_file = ioutil::append_file(outfile, cipher, cipher_len);
free(header);
free(cipher);
return ret_write_file;
}
#include <iostream>
#include "aes_gcm.h"
namespace util {
namespace crypto {
int AesGcm::aes_gcm_key(
const unsigned char* key,
const unsigned char* iv,
EVP_CIPHER_CTX* e_ctx,
EVP_CIPHER_CTX* d_ctx) {
int ret = 0;
if (e_ctx != NULL) {
ret = EVP_EncryptInit_ex(e_ctx, EVP_aes_256_gcm(), NULL, NULL, NULL);
if (ret != 1) {
return -1;
}
ret = EVP_CIPHER_CTX_ctrl(e_ctx, EVP_CTRL_GCM_SET_IVLEN, AES_GCM_IV_LENGTH, NULL);
if (ret != 1) {
return -2;
}
ret = EVP_EncryptInit_ex(e_ctx, NULL, NULL, key, iv);
if (ret != 1) {
return -3;
}
}
// initial decrypt ctx
if (d_ctx != NULL) {
ret = EVP_DecryptInit_ex(d_ctx, EVP_aes_256_gcm(), NULL, NULL, NULL);
if (!ret) {
return -1;
}
ret = EVP_CIPHER_CTX_ctrl(d_ctx, EVP_CTRL_GCM_SET_IVLEN, AES_GCM_IV_LENGTH, NULL);
if (!ret) {
return -2;
}
ret = EVP_DecryptInit_ex(d_ctx, NULL, NULL, key, iv);
if (!ret) {
return -3;
}
}
return 0;
}
int AesGcm::aes_gcm_key(
const std::string& key_hex,
const std::string& iv_hex,
EVP_CIPHER_CTX* e_ctx,
EVP_CIPHER_CTX* d_ctx) {
// check key_hex and iv_hex length
if (key_hex.length() != AES_GCM_KEY_LENGTH * 2
|| iv_hex.length() != AES_GCM_IV_LENGTH * 2) {
return -4;
}
unsigned char key[AES_GCM_KEY_LENGTH];
unsigned char iv[AES_GCM_IV_LENGTH];
int ret = Basic::hex_to_byte(key_hex, key);
if (ret < 0) {
return -5;
}
ret = Basic::hex_to_byte(iv_hex, iv);
if (ret < 0) {
return -5;
}
return aes_gcm_key(key, iv, e_ctx, d_ctx);
}
int AesGcm::encrypt_aes_gcm(
const unsigned char* plaintext,
const int& len,
const unsigned char* key,
const unsigned char* iv,
unsigned char* ciphertext,
int& out_len) {
EVP_CIPHER_CTX* ctx = NULL;
int ret = 0;
int update_len = 0;
int ciphertext_len = 0;
unsigned char tag_char[AES_GCM_TAG_LENGTH];
if (!(ctx = EVP_CIPHER_CTX_new())) {
return -1;
}
// initial context
ret = aes_gcm_key(key, iv, ctx, NULL);
if (ret) {
EVP_CIPHER_CTX_free(ctx);
return -1;
}
// encryption
ret = EVP_EncryptUpdate(ctx, ciphertext, &update_len, plaintext, len);
if (ret != 1) {
EVP_CIPHER_CTX_free(ctx);
return -2;
}
ciphertext_len = update_len;
ret = EVP_EncryptFinal_ex(ctx, ciphertext + ciphertext_len, &update_len);
if (1 != ret) {
EVP_CIPHER_CTX_free(ctx);
return -3;
}
ciphertext_len += update_len;
// Get the tags for authentication
ret = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, AES_GCM_TAG_LENGTH, tag_char);
if (1 != ret) {
EVP_CIPHER_CTX_free(ctx);
return -4;
}
EVP_CIPHER_CTX_free(ctx);
//append the tags to the end of encryption text
for (int i = 0; i < AES_GCM_TAG_LENGTH; ++i) {
ciphertext[ciphertext_len + i] = tag_char[i];
}
out_len = ciphertext_len + AES_GCM_TAG_LENGTH;
return 0;
}
int AesGcm::decrypt_aes_gcm(
const unsigned char* ciphertext,
const int& len,
const unsigned char* key,
const unsigned char* iv,
unsigned char* plaintext,
int& out_len) {
EVP_CIPHER_CTX* ctx = NULL;
int ret = 0;
int update_len = 0;
int cipher_len = 0;
int plaintext_len = 0;
unsigned char tag_char[AES_GCM_TAG_LENGTH];
// get the tag at the end of ciphertext
for (int i = 0; i < AES_GCM_TAG_LENGTH; ++i) {
tag_char[i] = ciphertext[len - AES_GCM_TAG_LENGTH + i];
}
cipher_len = len - AES_GCM_TAG_LENGTH;
// initial aes context
if (!(ctx = EVP_CIPHER_CTX_new())) {
return -1;
}
ret = aes_gcm_key(key, iv, NULL, ctx);
if (ret) {
EVP_CIPHER_CTX_free(ctx);
return -1;
}
// decryption
ret = EVP_DecryptUpdate(ctx, plaintext, &update_len, ciphertext, cipher_len);
if (ret != 1) {
EVP_CIPHER_CTX_free(ctx);
return -2;
}
plaintext_len = update_len;
// check if the tag is equal to the decrption tag
ret = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, AES_GCM_TAG_LENGTH, tag_char);
if (!ret) {
EVP_CIPHER_CTX_free(ctx);
// decrption failed
return -3;
}
ret = EVP_DecryptFinal_ex(ctx, plaintext + update_len, &update_len);
if (ret <= 0) {
EVP_CIPHER_CTX_free(ctx);
return -4;
}
plaintext_len += update_len;
EVP_CIPHER_CTX_free(ctx);
out_len = plaintext_len;
return 0;
}
} // namespace crypt
} // namespace common
#ifndef PADDLE_MODEL_PROTECT_UTIL_CRYPTO_AES_GCM_H
#define PADDLE_MODEL_PROTECT_UTIL_CRYPTO_AES_GCM_H
#include <iostream>
#include <openssl/aes.h>
#include <openssl/evp.h>
#include <string>
#include "basic.h"
namespace util {
namespace crypto {
// aes key 32 byte for 256 bit
#define AES_GCM_KEY_LENGTH 32
// aes tag 16 byte for 128 bit
#define AES_GCM_TAG_LENGTH 16
// aes iv 12 byte for 96 bit
#define AES_GCM_IV_LENGTH 16
class AesGcm {
public:
/**
* \brief initial aes-gcm-256 context use key & iv
*
* \note initial aes-gcm-256 context use key & iv. gcm mode
* will generate a tag(16 byte), so the ciphertext's length
* should be longer 16 byte than plaintext.
*
*
* \param plaintext plain text to be encrypted(in)
* \param len plain text's length(in)
* \param key aes key (in)
* \param iv aes iv (in)
* \param ciphertext encrypted text(out)
* \param out_len encrypted length(out)
*
* \return return 0 if successful
* -1 EVP_CIPHER_CTX_new or aes_gcm_key error
* -2 EVP_EncryptUpdate error
* -3 EVP_EncryptFinal_ex error
* -4 EVP_CIPHER_CTX_ctrl error
*/
static int encrypt_aes_gcm(
const unsigned char* plaintext,
const int& len,
const unsigned char* key,
const unsigned char* iv,
unsigned char* ciphertext,
int& out_len);
/**
* \brief encrypt using aes-gcm-256
*
* \note encrypt using aes-gcm-256
*
* \param ciphertext cipher text to be decrypted(in)
* \param len plain text's length(in)
* \param key aes key (in)
* \param iv aes iv (in)
* \param plaintext decrypted text(out)
* \param out_len decrypted length(out)
*
* \return return 0 if successful
* -1 EVP_CIPHER_CTX_new or aes_gcm_key error
* -2 EVP_DecryptUpdate error
* -3 EVP_CIPHER_CTX_ctrl error
* -4 EVP_DecryptFinal_ex error
*/
static int decrypt_aes_gcm(
const unsigned char* ciphertext,
const int& len,
const unsigned char* key,
const unsigned char* iv,
unsigned char* plaintext,
int& out_len);
private:
/**
* \brief initial aes-gcm-256 context use key & iv
*
* \note initial aes-gcm-256 context use key & iv
*
* \param key aes key (in)
* \param iv aes iv (in)
* \param e_ctx encryption context(out)
* \param d_ctx decryption context(out)
*
* \return return 0 if successful
* -1 EVP_xxcryptInit_ex error
* -2 EVP_CIPHER_CTX_ctrl error
* -3 EVP_xxcryptInit_ex error
*/
static int aes_gcm_key(
const unsigned char* key,
const unsigned char* iv,
EVP_CIPHER_CTX* e_ctx,
EVP_CIPHER_CTX* d_ctx);
/**
* \brief initial aes-gcm-256 context use key & iv
*
* \note initial aes-gcm-256 context use key & iv
*
* \param key aes key (in)
* \param iv aes iv (in)
* \param e_ctx encryption context(out)
* \param d_ctx decryption context(out)
*
* \return return 0 if successful
* -1 EVP_xxcryptInit_ex error
* -2 EVP_CIPHER_CTX_ctrl error
* -3 EVP_xxcryptInit_ex error
* -4 invalid key length or iv length
* -5 hex_to_byte error
*/
static int aes_gcm_key(
const std::string& key_hex,
const std::string& iv_hex,
EVP_CIPHER_CTX* e_ctx,
EVP_CIPHER_CTX* d_ctx);
};
} // namespace crypt
} // namespace common
#endif // PADDLE_MODEL_PROTECT_UTIL_CRYPTO_AES_GCM_H
\ No newline at end of file
#include "base64.h"
using std::string;
namespace baidu {
namespace base {
namespace base64 {
namespace {
const string base64_chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
inline bool is_base64(unsigned char c) {
return isalnum(c) || (c == '+') || (c == '/');
}
inline size_t encode_len(size_t input_len) {
return (input_len + 2) / 3 * 4;
}
void encode_char_array(unsigned char *encode_block, const unsigned char *decode_block) {
encode_block[0] = (decode_block[0] & 0xfc) >> 2;
encode_block[1] = ((decode_block[0] & 0x03) << 4) + ((decode_block[1] & 0xf0) >> 4);
encode_block[2] = ((decode_block[1] & 0x0f) << 2) + ((decode_block[2] & 0xc0) >> 6);
encode_block[3] = decode_block[2] & 0x3f;
}
void decode_char_array(unsigned char *encode_block, unsigned char *decode_block) {
for (int i = 0; i < 4; ++i) {
encode_block[i] = base64_chars.find(encode_block[i]);
}
decode_block[0] = (encode_block[0] << 2) + ((encode_block[1] & 0x30) >> 4);
decode_block[1] = ((encode_block[1] & 0xf) << 4) + ((encode_block[2] & 0x3c) >> 2);
decode_block[2] = ((encode_block[2] & 0x3) << 6) + encode_block[3];
}
}
string base64_encode(const string& input) {
string output;
size_t i = 0;
unsigned char decode_block[3];
unsigned char encode_block[4];
for (string::size_type len = 0; len != input.size(); ++len) {
decode_block[i++] = input[len];
if (i == 3) {
encode_char_array(encode_block, decode_block);
for (i = 0; i < 4; ++i) {
output += base64_chars[encode_block[i]];
}
i = 0;
}
}
if (i > 0) {
for (size_t j = i; j < 3; ++j) {
decode_block[j] = '\0';
}
encode_char_array(encode_block, decode_block);
for (size_t j = 0; j < i + 1; ++j) {
output += base64_chars[encode_block[j]];
}
while (i++ < 3) {
output += '=';
}
}
return output;
}
string base64_decode(const string& encoded_string) {
int in_len = encoded_string.size();
int i = 0;
int len = 0;
unsigned char encode_block[4];
unsigned char decode_block[3];
string output;
while (in_len-- && (encoded_string[len] != '=') && is_base64(encoded_string[len])) {
encode_block[i++] = encoded_string[len];
len++;
if (i == 4) {
decode_char_array(encode_block, decode_block);
for (int j = 0; j < 3; ++j) {
output += decode_block[j];
}
i = 0;
}
}
if (i > 0) {
for (int j = i; j < 4; ++j) {
encode_block[j] = 0;
}
decode_char_array(encode_block, decode_block);
for (int j = 0; j < i - 1; ++j) {
output += decode_block[j];
}
}
return output;
}
}
}
}
\ No newline at end of file
#include <vector>
#include <string>
#ifndef PADDLE_MODEL_PROTECT_UTIL_CRYPTO_BASE64_UTILS_H
#define PADDLE_MODEL_PROTECT_UTIL_CRYPTO_BASE64_UTILS_H
namespace baidu {
namespace base {
namespace base64 {
std::string base64_encode(const std::string& input);
std::string base64_decode(const std::string& input);
}
}
}
#endif //PADDLE_MODEL_PROTECT_BASE64_UTILS_H
#include "basic.h"
namespace util {
namespace crypto {
int Basic::byte_to_hex(
const unsigned char* in_byte,
int len,
std::string& out_hex) {
std::ostringstream oss;
oss << std::hex << std::setfill('0');
for (int i = 0; i < len; ++i) {
oss << std::setw(2) << int(in_byte[i]);
}
out_hex = oss.str();
return 0;
}
int Basic::hex_to_byte(
const std::string& in_hex,
unsigned char* out_byte) {
int i = 0;
int j = 0;
int len = in_hex.length() / 2;
const unsigned char* hex;
if (in_hex.length() % 2 != 0 || out_byte == NULL) {
return -1;
}
hex = (unsigned char*) in_hex.c_str();
for (; j < len; i += 2, ++j) {
unsigned char high = hex[i];
unsigned char low = hex[i + 1];
if (high >= '0' && high <= '9') {
high = high - '0';
} else if (high >= 'A' && high <= 'F') {
high = high - 'A' + 10;
} else if (high >= 'a' && high <= 'f') {
high = high - 'a' + 10;
} else {
return -2;
}
if (low >= '0' && low <= '9') {
low = low - '0';
} else if (low >= 'A' && low <= 'F') {
low = low - 'A' + 10;
} else if (low >= 'a' && low <= 'f') {
low = low - 'a' + 10;
} else {
return -2;
}
out_byte[j] = high << 4 | low;
}
return 0;
}
int Basic::random(unsigned char* random, int len) {
std::random_device rd;
int i = 0;
if (len <= 0 || random == NULL) {
return -1;
}
for (; i < len; ++i) {
random[i] = rd() % 256;
}
return 0;
}
}
} // namespace common
\ No newline at end of file
#ifndef PADDLE_MODEL_PROTECT_UTIL_BASIC_H
#define PADDLE_MODEL_PROTECT_UTIL_BASIC_H
#include <iomanip>
#include <iostream>
#include <random>
#include <string>
#include <sstream>
namespace util {
namespace crypto {
class Basic {
public:
/**
* \brief byte to hex
*
* \note byte to hex.
*
*
* \param in_byte byte array(in)
* \param len byte array length(in)
* \param out_hex the hex string(in)
*
*
* \return return 0 if successful
*/
static int byte_to_hex(
const unsigned char* in_byte,
int len,
std::string& out_hex);
/**
* \brief hex to byte
*
* \note hex to byte.
*
*
* \param in_hex the hex string(in)
* \param out_byte byte array(out)
*
* \return return 0 if successful
* -1 invalid in_hex
*/
static int hex_to_byte(
const std::string& in_hex,
unsigned char* out_byte);
/**
* \brief get random char for length
*
* \note get random char for length
*
*
* \param array to be random(out)
* \param len array length(in)
*
* \return return 0 if successful
* -1 invalid parameters
*/
static int random(
unsigned char* random,
int len);
};
}
} // namespace common
#endif // PADDLE_MODEL_PROTECT_UTIL_BASIC_H
#include "sha256_utils.h"
#include <iomanip>
#include <stdio.h>
#include <openssl/sha.h>
#include <sstream>
namespace util {
namespace crypto {
void SHA256Utils::sha256(const void* data, size_t len, unsigned char* md) {
SHA256_CTX sha_ctx = {};
SHA256_Init(&sha_ctx);
SHA256_Update(&sha_ctx, data, len);
SHA256_Final(md, &sha_ctx);
}
std::vector<unsigned char> SHA256Utils::sha256(const void* data, size_t len) {
std::vector<unsigned char> md(32);
sha256(data, len, &md[0]);
return md;
}
std::vector<unsigned char> SHA256Utils::sha256(const std::vector<unsigned char>& data) {
return sha256(&data[0], data.size());
}
std::string SHA256Utils::sha256_string(const void* data, size_t len) {
std::vector<unsigned char> md = sha256(data, len);
std::ostringstream oss;
oss << std::hex << std::setfill('0');
for (unsigned char c : md) {
oss << std::setw(2) << int(c);
}
return oss.str();
}
std::string SHA256Utils::sha256_string(const std::vector<unsigned char>& data) {
return sha256_string(&data[0], data.size());
}
std::string SHA256Utils::sha256_string(const std::string& string) {
return sha256_string(string.c_str(), string.size());
}
std::string SHA256Utils::sha256_file(const std::string& path) {
FILE* file = fopen(path.c_str(), "rb");
if (!file) {
return "";
}
unsigned char hash[SHA256_DIGEST_LENGTH];
SHA256_CTX sha_ctx = {};
SHA256_Init(&sha_ctx);
const int size = 32768;
void* buffer = malloc(size);
if (!buffer) {
fclose(file);
return "";
}
int read = 0;
while ((read = fread(buffer, 1, size, file))) {
SHA256_Update(&sha_ctx, buffer, read);
}
SHA256_Final(hash, &sha_ctx);
std::ostringstream oss;
oss << std::hex << std::setfill('0');
for (unsigned char c : hash) {
oss << std::setw(2) << int(c);
}
fclose(file);
free(buffer);
return oss.str();
}
}
}
#include <vector>
#include <string>
#ifndef PADDLE_MODEL_PROTECT_UTIL_CRYPTO_SHA256_UTILS_H
#define PADDLE_MODEL_PROTECT_UTIL_CRYPTO_SHA256_UTILS_H
namespace util {
namespace crypto {
class SHA256Utils {
public:
static void sha256(const void* data, size_t len, unsigned char* md);
static std::vector<unsigned char> sha256(const void* data, size_t len);
static std::vector<unsigned char> sha256(const std::vector<unsigned char>& data);
static std::string sha256_string(const void* data, size_t len);
static std::string sha256_string(const std::vector<unsigned char>& data);
static std::string sha256_string(const std::string& string);
static std::string sha256_file(const std::string& path);
};
}
}
#endif //PADDLE_MODEL_PROTECT_UTIL_CRYPTO_SHA256_UTILS_H
#ifdef linux
#include <unistd.h>
#include <dirent.h>
#endif
#ifdef WIN32
#include <windows.h>
#include <io.h>
#endif
#include <iostream>
#include <string.h>
#include <sys/stat.h>
#include <sys/types.h>
#include "io_utils.h"
#include "model_code.h"
#include "log.h"
namespace ioutil {
int read_file(const char* file_path, unsigned char** dataptr, size_t* sizeptr) {
FILE* fp = NULL;
fp = fopen(file_path, "rb");
if (fp == NULL) {
LOGD("[M]open file(%s) failed", file_path);
return CODE_OPEN_FAILED;
}
fseek(fp, 0, SEEK_END);
*sizeptr = ftell(fp);
*dataptr = (unsigned char*) malloc(sizeof(unsigned char) * (*sizeptr));
fseek(fp, 0, SEEK_SET);
fread(*dataptr, 1, *sizeptr, fp);
fclose(fp);
return CODE_OK;
}
int read_with_pos_and_length(const char* file_path, unsigned char* dataptr, size_t pos, size_t length) {
if (dataptr == NULL) {
LOGD("Read file pos dataptr = NULL");
return CODE_READ_FILE_PTR_IS_NULL;
}
FILE* fp = NULL;
if ((fp = fopen(file_path, "rb")) == NULL) {
LOGD("[M]open file(%s) failed", file_path);
return CODE_OPEN_FAILED;
}
fseek(fp, pos, SEEK_SET);
fread(dataptr, 1, length, fp);
fclose(fp);
return CODE_OK;
}
int read_with_pos(const char* file_path, size_t pos, unsigned char** dataptr, size_t* sizeptr) {
FILE* fp = NULL;
if ((fp = fopen(file_path, "rb")) == NULL) {
LOGD("[M]open file(%s) failed when read_with_pos", file_path);
return CODE_OPEN_FAILED;
}
fseek(fp, 0, SEEK_END);
size_t filesize = ftell(fp);
*sizeptr = filesize - pos;
*dataptr = (unsigned char*) malloc(sizeof(unsigned char) * (filesize - pos));
fseek(fp, pos, SEEK_SET);
fread(*dataptr, 1, filesize - pos, fp);
fclose(fp);
return CODE_OK;
}
int write_file(const char* file_path, const unsigned char* dataptr, size_t sizeptr) {
FILE* fp = NULL;
if ((fp = fopen(file_path, "wb")) == NULL) {
LOGD("[M]open file(%s) failed", file_path);
return CODE_OPEN_FAILED;
}
fwrite(dataptr, 1, sizeptr, fp);
fclose(fp);
return CODE_OK;
}
int append_file(const char* file_path, const unsigned char* data, size_t len) {
FILE* fp = fopen(file_path, "ab+");
if (fp == NULL) {
LOGD("[M]open file(%s) failed when append_file", file_path);
return CODE_OPEN_FAILED;
}
fwrite(data, sizeof(char), len, fp);
fclose(fp);
return CODE_OK;
}
size_t read_file_size(const char* file_path) {
FILE* fp = NULL;
fp = fopen(file_path, "rb");
if (fp == NULL) {
LOGD("[M]open file(%s) failed when read_file_size", file_path);
return 0;
}
fseek(fp, 0, SEEK_END);
size_t filesize = ftell(fp);
fclose(fp);
return filesize;
}
int read_file_to_file(const char* src_path, const char* dst_path) {
FILE* infp = NULL;
if ((infp = fopen(src_path, "rb")) == NULL) {
LOGD("[M]read src file failed when read_file_to_file");
return CODE_OPEN_FAILED;
}
fseek(infp, 0, SEEK_END);
size_t insize = ftell(infp);
char* content = (char*) malloc(sizeof(char) * insize);
fseek(infp, 0, SEEK_SET);
fread(content, 1, insize, infp);
fclose(infp);
FILE* outfp = NULL;
if ((outfp = fopen(dst_path, "wb")) == NULL) {
LOGD("[M]open dst file failed when read_file_to_file");
return CODE_OPEN_FAILED;
}
fwrite(content, 1, insize, outfp);
fclose(outfp);
free(content);
return CODE_OK;
}
int read_dir_files(const char* dir_path, std::vector<std::string>& files) {
#ifdef linux
struct dirent* ptr;
DIR* dir = NULL;
dir = opendir(dir_path);
if (dir == NULL) {
return -1; // CODE_NOT_EXIST_DIR
}
while ((ptr = readdir(dir)) != NULL) {
if (strcmp(ptr->d_name, ".") != 0 && strcmp(ptr->d_name, "..") != 0) {
files.push_back(ptr->d_name);
}
}
closedir(dir);
#endif
#ifdef WIN32
intptr_t handle;
struct _finddata_t fileinfo;
std::string tmp_dir(dir_path);
std::string::size_type idx = tmp_dir.rfind("\\*");
if (idx == std::string::npos || idx != tmp_dir.length() - 1)
{
tmp_dir.append("\\*");
}
handle = _findfirst(tmp_dir.c_str(), &fileinfo);
if (handle == -1) {
return -1;
}
do {
std::cout << "File name = " << fileinfo.name << std::endl;
if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0) {
files.push_back(fileinfo.name);
}
} while (!_findnext(handle, &fileinfo));
std::cout << files.size() << std::endl;
for (size_t i = 0; i < files.size(); i++)
{
std::cout << files[i] << std::endl;
}
_findclose(handle);
#endif
return files.size();
}
int dir_exist_or_mkdir(const char* dir) {
#ifdef WIN32
if (CreateDirectory(dir, NULL)) {
// return CODE_OK;
} else {
return CODE_MKDIR_FAILED;
}
#endif
#ifdef linux
if (access(dir, 0) != 0) {
mkdir(dir, S_IRWXU | S_IRWXG | S_IRWXO);
}
#endif
return CODE_OK;
}
}
#include <iostream>
#include <memory>
#include <vector>
#include <string>
#ifndef PADDLE_MODEL_PROTECT_IO_UTILS_H
#define PADDLE_MODEL_PROTECT_IO_UTILS_H
namespace ioutil {
int read_file(const char* file_path, unsigned char** dataptr, size_t* sizeptr);
int read_with_pos_and_length(const char* file_path, unsigned char* dataptr, size_t pos, size_t length);
int read_with_pos(const char* file_path, size_t pos, unsigned char** dataptr, size_t* sizeptr);
int write_file(const char* file_path, const unsigned char* dataptr, size_t sizeptr);
int append_file(const char* file_path, const unsigned char* data, size_t len);
size_t read_file_size(const char* file_path);
int read_file_to_file(const char* src_path, const char* dst_path);
int dir_exist_or_mkdir(const char* dir);
/**
* @return files.size()
*/
int read_dir_files(const char* dir_path, std::vector<std::string>& files);
}
#endif //PADDLE_MODEL_PROTECT_IO_UTILS_H
#ifndef PADDLE_MODEL_PROTECT_UTIL_LOG_H
#define PADDLE_MODEL_PROTECT_UTIL_LOG_H
#include <stdio.h>
#define LOGD(fmt,...)\
printf("{%s:%u}:" fmt "\n", __FUNCTION__, __LINE__, ##__VA_ARGS__)
#endif //PADDLE_MODEL_PROTECT_UTIL_LOG_H
#include <sys/timeb.h>
#include <string.h>
#include <model_code.h>
#include <algorithm>
#include "system_utils.h"
#include "crypto/basic.h"
#include "crypto/sha256_utils.h"
#include "io_utils.h"
#include "log.h"
#include "../constant/constant_model.h"
const char alphabet[] = "abcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*(){}[]<>?~";
namespace util {
int SystemUtils::intN(int n) {
return rand() % n;
}
std::string SystemUtils::random_key_iv(int len) {
unsigned char* tmp = (unsigned char*) malloc(sizeof(unsigned char) * len);
int ret = util::crypto::Basic::random(tmp, len);
std::string tmp_str(reinterpret_cast<const char*>(tmp), len);
free(tmp);
return tmp_str;
}
std::string SystemUtils::random_str(int len) {
unsigned char* tmp = (unsigned char*) malloc(sizeof(unsigned char) * len);
int ret = util::crypto::Basic::random(tmp, len);
std::string tmp_str(reinterpret_cast<const char*>(tmp), len);
free(tmp);
return tmp_str;
}
int SystemUtils::check_key_match(const char* key, const char* filepath) {
std::string aes_key_iv(key);
std::string sha256_aes_key_iv = util::crypto::SHA256Utils::sha256_string(aes_key_iv);
unsigned char* data_pos = (unsigned char*) malloc(sizeof(unsigned char) * 64);
int ret =
ioutil::read_with_pos_and_length(filepath, data_pos, constant::MAGIC_NUMBER_LEN + constant::VERSION_LEN, 64);
if (ret != CODE_OK) {
LOGD("[M]read file failed when check key");
return ret;
}
std::string check_str((char*) data_pos, 64);
free(data_pos);
if (strcmp(sha256_aes_key_iv.c_str(), check_str.c_str()) != 0) {
return CODE_KEY_NOT_MATCH;
}
return CODE_OK;
}
/**
*
* @param filepath
* @return 0 - file encrypted 1 - file unencrypted
*/
int SystemUtils::check_file_encrypted(const char* filepath) {
size_t read_len = constant::MAGIC_NUMBER_LEN + constant::VERSION_LEN;
unsigned char* data_pos = (unsigned char*) malloc(sizeof(unsigned char) * read_len);
if (ioutil::read_with_pos_and_length(filepath, data_pos, 0, read_len) != CODE_OK) {
LOGD("check file failed when read %s(file)", filepath);
return CODE_OPEN_FAILED;
}
std::string tag(constant::MAGIC_NUMBER);
tag.append(constant::VERSION);
int ret_cmp = strcmp(tag.c_str(), (const char*) data_pos) == 0 ? 0 : 1;
free(data_pos);
return ret_cmp;
}
int SystemUtils::check_pattern_exist(const std::vector<std::string>& vecs, const std::string& pattern) {
if (std::find(vecs.begin(), vecs.end(), pattern) == vecs.end()) {
return -1; // not exist
} else {
return 0; // exist
}
}
}
#include <string>
#include <vector>
#ifndef PADDLE_MODEL_PROTECT_SYSTEM_UTIL_H
#define PADDLE_MODEL_PROTECT_SYSTEM_UTIL_H
namespace util {
class SystemUtils {
public:
static std::string random_key_iv(int len);
static std::string random_str(int len);
static int check_key_match(const char* key, const char* filepath);
static int check_file_encrypted(const char* filepath);
static int check_pattern_exist(const std::vector<std::string>& vecs, const std::string& pattern);
private:
inline static int intN(int n);
};
}
#endif //PADDLE_MODEL_PROTECT_SYSTEM_UTIL_H
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册