diff --git a/deploy/encryption/include/model_code.h b/deploy/encryption/include/model_code.h index 9c0092e1594521bacfec2ff7450976e47572e57c..dc09445cfd43023c3074ce2814cad02af81695b1 100644 --- a/deploy/encryption/include/model_code.h +++ b/deploy/encryption/include/model_code.h @@ -18,7 +18,8 @@ extern "C" { CODE_MODEL_FILE_NOT_EXIST = 108, CODE_PARAMS_FILE_NOT_EXIST = 109, CODE_MODEL_YML_FILE_NOT_EXIST = 110, - CODE_MKDIR_FAILED = 111 + CODE_MKDIR_FAILED = 111, + CODE_MALLOC_FAILED = 112 }; diff --git a/deploy/encryption/src/safeapi/paddle_model_decrypt.cpp b/deploy/encryption/src/safeapi/paddle_model_decrypt.cpp index 8c0186b44aae91786fff68f038bf0efca3f8dd38..7eb881dbdf9dfb595775c278f3388f46b8aa0997 100644 --- a/deploy/encryption/src/safeapi/paddle_model_decrypt.cpp +++ b/deploy/encryption/src/safeapi/paddle_model_decrypt.cpp @@ -18,9 +18,10 @@ int paddle_check_file_encrypted(const char* file_path) { return util::SystemUtils::check_file_encrypted(file_path); } -std::string decrypt_file(const char* file_path, const char* key) { - int ret = paddle_check_file_encrypted(file_path); +std::string decrypt_file_with_code(const char* file_path, const char* key, int* decrypt_code) { +int ret = paddle_check_file_encrypted(file_path); if (ret != CODE_OK) { + *decrypt_code = ret; LOGD("[M]check file encrypted failed, code: %d", ret); return std::string(); } @@ -29,6 +30,7 @@ std::string decrypt_file(const char* file_path, const char* key) { std::string key_str = baidu::base::base64::base64_decode(std::string(key)); int ret_check = util::SystemUtils::check_key_match(key_str.c_str(), file_path); if (ret_check != CODE_OK) { + *decrypt_code = ret_check; LOGD("[M]check key failed in decrypt_file, code: %d", ret_check); return std::string(); } @@ -44,6 +46,7 @@ std::string decrypt_file(const char* file_path, const char* key) { size_t data_len = 0; int ret_read_data = ioutil::read_with_pos(file_path, pos, &dataptr, &data_len); if (ret_read_data != CODE_OK) { + *decrypt_code = ret_read_data; LOGD("[M]read file failed, code = %d", ret_read_data); return std::string(); } @@ -51,6 +54,11 @@ std::string decrypt_file(const char* file_path, const char* key) { // decrypt model data size_t model_plain_len = data_len - AES_GCM_TAG_LENGTH; unsigned char* model_plain = (unsigned char*) malloc(sizeof(unsigned char) * model_plain_len); + if (model_plain == NULL) { + *decrypt_code = CODE_MALLOC_FAILED; + LOGD("model_plain malloc failed(decrypt_file), code: %d", CODE_MALLOC_FAILED); + return std::string(); + } int ret_decrypt_file = util::crypto::AesGcm::decrypt_aes_gcm( @@ -64,12 +72,23 @@ std::string decrypt_file(const char* file_path, const char* key) { free(aes_key); free(aes_iv); if (ret_decrypt_file != CODE_OK) { + *decrypt_code = ret_decrypt_file; free(model_plain); LOGD("[M]decrypt file failed, decrypt ret = %d", ret_decrypt_file); return std::string(); } - std::string result((const char*)model_plain); + std::string result((const char*)model_plain, (const char*)model_plain + model_plain_len); free(model_plain); + *decrypt_code = CODE_OK; + return result; +} + +std::string decrypt_file(const char* file_path, const char* key) { + int decrypt_code = 0; + std::string result = decrypt_file_with_code(file_path, key, &decrypt_code); + if (decrypt_code != CODE_OK) { + LOGD("[M]decrypt file failed(decrypt_file), decrypt ret = %d", decrypt_code); + } return result; } @@ -139,6 +158,10 @@ int paddle_security_load_model( // decrypt model data model_plain_len = model_data_len - AES_GCM_TAG_LENGTH; model_plain = (unsigned char*) malloc(sizeof(unsigned char) * model_plain_len); + if (model_plain == NULL) { + LOGD("model_plain malloc failed"); + return CODE_MALLOC_FAILED; + } int ret_decrypt_model = util::crypto::AesGcm::decrypt_aes_gcm(model_dataptr, @@ -176,6 +199,10 @@ int paddle_security_load_model( // decrypt params data params_plain_len = params_data_len - AES_GCM_TAG_LENGTH; params_plain = (unsigned char*) malloc(sizeof(unsigned char) * params_plain_len); + if (params_plain == NULL) { + LOGD("params_plain malloc failed"); + return CODE_MALLOC_FAILED; + } int ret_decrypt_params = util::crypto::AesGcm::decrypt_aes_gcm(params_dataptr, @@ -202,13 +229,8 @@ int paddle_security_load_model( config->SetModelBuffer(reinterpret_cast(model_plain), model_plain_len, reinterpret_cast(params_plain), params_plain_len); - if (m_en_flag == 1) { - free(model_dataptr); - } - - if (p_en_flag == 1) { - free(params_dataptr); - } + free(model_plain); + free(params_plain); return CODE_OK; } diff --git a/deploy/encryption/src/util/io_utils.cpp b/deploy/encryption/src/util/io_utils.cpp index dd3c2971c2f737c236efb4729e2d52afd632ef1c..34a79607c26f8046e731e62481c725a0d34023f1 100644 --- a/deploy/encryption/src/util/io_utils.cpp +++ b/deploy/encryption/src/util/io_utils.cpp @@ -28,6 +28,10 @@ int read_file(const char* file_path, unsigned char** dataptr, size_t* sizeptr) { fseek(fp, 0, SEEK_END); *sizeptr = ftell(fp); *dataptr = (unsigned char*) malloc(sizeof(unsigned char) * (*sizeptr)); + if (*dataptr == NULL) { + LOGD("malloc failed when read file"); + return CODE_MALLOC_FAILED; + } fseek(fp, 0, SEEK_SET); fread(*dataptr, 1, *sizeptr, fp); @@ -68,6 +72,10 @@ int read_with_pos(const char* file_path, size_t pos, unsigned char** dataptr, si *sizeptr = filesize - pos; *dataptr = (unsigned char*) malloc(sizeof(unsigned char) * (filesize - pos)); + if (*dataptr == NULL) { + LOGD("malloc failed when read file"); + return CODE_MALLOC_FAILED; + } fseek(fp, pos, SEEK_SET); fread(*dataptr, 1, filesize - pos, fp); fclose(fp); @@ -172,17 +180,17 @@ int read_dir_files(const char* dir_path, std::vector& files) { } do { - std::cout << "File name = " << fileinfo.name << std::endl; + // std::cout << "File name = " << fileinfo.name << std::endl; if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0) { files.push_back(fileinfo.name); } } while (!_findnext(handle, &fileinfo)); -std::cout << files.size() << std::endl; + /* std::cout << files.size() << std::endl; for (size_t i = 0; i < files.size(); i++) { std::cout << files[i] << std::endl; - } + } */ _findclose(handle); #endif