From 3168be11242b79fb67aa7fc0ea89152da1cf24bc Mon Sep 17 00:00:00 2001
From: Yanzhan Yang
Date: Mon, 8 Jul 2019 13:03:57 +0800
Subject: [PATCH] enhance quantification tool to dump float32 params. (#1731)

---
 tools/quantification/convert.cpp | 200 ++++++++++++++++++++++++++++---
 1 file changed, 181 insertions(+), 19 deletions(-)

diff --git a/tools/quantification/convert.cpp b/tools/quantification/convert.cpp
index 282b22073f..dc341d5d5a 100644
--- a/tools/quantification/convert.cpp
+++ b/tools/quantification/convert.cpp
@@ -68,7 +68,7 @@ std::shared_ptr<ProgramDesc> loadParams(const std::string &model_path) {
 }
 
-void LoadWithDump(const paddle_mobile::framework::VarDesc &var_desc, char **dataP, FILE *out_file) {
+void LoadWithDumpForInt8(const paddle_mobile::framework::VarDesc &var_desc, char **dataP, FILE *out_file) {
   // 1. version
   uint32_t version = *reinterpret_cast<uint32_t *>(*dataP);
 
@@ -182,8 +182,7 @@ void LoadWithDump(const paddle_mobile::framework::VarDesc &var_desc, char **data
 }
 
 void
-quantificate_combined(const std::string &model_path, const std::string &param_path, const std::string &param_min_path) {
-
+quantificate_combined_int8(const std::string &model_path, const std::string &param_path, const std::string &param_min_path) {
   auto program = loadParams(model_path);
   char *origin_data = Get_binary_data(param_path);
   char *data = origin_data;
@@ -194,17 +193,15 @@ void quantificate_combined(const std::string &model_path, const std::string &param_pa
         if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
           continue;
         }
-        LoadWithDump(*var_desc, &data, out_file);
+        LoadWithDumpForInt8(*var_desc, &data, out_file);
       }
     }
   }
   fclose(out_file);
   delete origin_data;
-
 }
 
-void quantificate_seperated(const std::string model_dir, const std::string param_min_path) {
-
+void quantificate_seperated_int8(const std::string model_dir, const std::string param_min_path) {
   auto program = loadParams(model_dir + "/__model__");
 
   std::string shell_command = "mkdir " + param_min_path;
@@ -220,25 +217,180 @@ void quantificate_seperated(const std::string model_dir, const std::string param
         FILE *out_file = fopen(file_name.c_str(), "wb");
         char *origin_data = Get_binary_data(model_dir + "/" + var_desc->Name());
         char *data = origin_data;
-        LoadWithDump(*var_desc, &data, out_file);
+        LoadWithDumpForInt8(*var_desc, &data, out_file);
         delete origin_data;
         fclose(out_file);
       }
     }
   }
+}
+
+void LoadWithDumpForFloat32(const paddle_mobile::framework::VarDesc &var_desc, char **dataP, FILE *out_file) {
+  // 1. version
+  uint32_t version = *reinterpret_cast<uint32_t *>(*dataP);
+
+  // write version
+  fwrite(&version, kSize32, 1, out_file);
+
+  *dataP += kSize32;
+
+  // 2. LoD information
+  auto *lod_level_ptr = new uint64_t();
+  memcpy(lod_level_ptr, *dataP, kSize64);
+
+  uint64_t lod_level = 0;
+  // write LoD information
+  fwrite(&lod_level, kSize64, 1, out_file);
+  delete lod_level_ptr;
+
+  *dataP += kSize64;
+
+  for (uint64_t i = 0; i < lod_level; ++i) {
+    uint64_t size = *reinterpret_cast<uint64_t *>(*dataP);
+    // write lod size
+    fwrite(&size, kSize64, 1, out_file);
+    (*dataP) += kSize64;
+
+    std::vector<size_t> tmp(size / sizeof(size_t));
+    for (unsigned long &k : tmp) {
+      k = *reinterpret_cast<size_t *>(*dataP);
+      (*dataP) += sizeof(size_t);
+    }
+    // write lod size vector
+    fwrite(tmp.data(), sizeof(size_t), tmp.size(), out_file);
+  }
+
+  // 3. tensor version
+  uint32_t tensor_version = *reinterpret_cast<uint32_t *>(*dataP);
+  // write tensor version
+  fwrite(&tensor_version, kSize32, 1, out_file);
+  (*dataP) += kSize32;
+
+  // 4. tensor desc
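+  // the TensorDesc bytes read below (dims, data type) are copied to the output unchanged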
+  int32_t size = *reinterpret_cast<int32_t *>(*dataP);
+  // write tensor desc
+  fwrite(&size, sizeof(int32_t), 1, out_file);
+  (*dataP) += sizeof(int32_t);
+
+  std::unique_ptr<char[]> buf(new char[size]);
+  for (int m = 0; m < size; ++m) {
+    buf.get()[m] = (*dataP)[m];
+  }
+
+  fwrite(buf.get(), sizeof(char), static_cast<size_t>(size), out_file);
+  (*dataP) += (sizeof(char) * size);
+
+  const paddle_mobile::framework::TensorDesc &desc = var_desc.Tensor_desc();
+  int memory_size = 1;
+  for (auto l : desc.Dims()) {
+    memory_size *= l;
+  }
+
+  void *memory = nullptr;
+  int type_size = 0;
+  switch (desc.DataType()) {
+    case paddle_mobile::framework::VARTYPE_TYPE_FP16:
+      type_size = 2;
+      break;
+    case paddle_mobile::framework::VARTYPE_TYPE_FP32:
+      type_size = 4;
+      break;
+    case paddle_mobile::framework::VARTYPE_TYPE_FP64:
+      type_size = 8;
+      break;
+    case paddle_mobile::framework::VARTYPE_TYPE_INT32:
+      type_size = 4;
+      break;
+    case paddle_mobile::framework::VARTYPE_TYPE_INT64:
+      type_size = 8;
+      break;
+    case paddle_mobile::framework::VARTYPE_TYPE_BOOL:
+      type_size = 1;
+      break;
+    default:
+      break;
+  }
+  size_t tensorSize = sizeof(char) * memory_size * type_size;
+
+  memory = new char[tensorSize];
+
+  for (size_t n = 0; n < tensorSize; ++n) {
+    static_cast<char *>(memory)[n] = (*dataP)[n];
+  }
+  *dataP += tensorSize;
+
+  // for float 32
+  float min_value = std::numeric_limits<float>::max();
+  float max_value = std::numeric_limits<float>::lowest();
+
+  for (int k = 0; k < memory_size; ++k) {
+    min_value = std::min(min_value, static_cast<float *>(memory)[k]);
+    max_value = std::max(max_value, static_cast<float *>(memory)[k]);
+  }
+
+  float diff = 0.0;
+  for (int g = 0; g < memory_size; ++g) {
+    float value = static_cast<float *>(memory)[g];
+    auto factor = (uint8_t) std::round((value - min_value) / (max_value - min_value) * 255);
+    float value_quantized = min_value + (factor / 255.0) * (max_value - min_value);
+    diff += std::abs(value - value_quantized);
+    fwrite(&value_quantized, sizeof(float), 1, out_file);
+  }
+  std::cout << "avg diff caused by quantization for var " << var_desc.Name() << " is: " << (diff / memory_size) << std::endl;
+}
+
+void
+quantificate_combined_float32(const std::string &model_path, const std::string &param_path, const std::string &param_min_path) {
+  auto program = loadParams(model_path);
+  char *origin_data = Get_binary_data(param_path);
+  char *data = origin_data;
+  FILE *out_file = fopen(param_min_path.c_str(), "wb");
+  for (const auto &block : program->Blocks()) {
+    for (const auto &var_desc : block->Vars()) {
+      if (var_desc->Persistable()) {
+        if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
+          continue;
+        }
+        LoadWithDumpForFloat32(*var_desc, &data, out_file);
+      }
+    }
+  }
+  fclose(out_file);
+  delete[] origin_data;
+}
 
-int main(int argc, char **argv) {
+void quantificate_seperated_float32(const std::string model_dir, const std::string param_min_path) {
+  auto program = loadParams(model_dir + "/__model__");
 
-  const std::string kNoteEg = "( eg: ./quantify 1 your_combined_model_path output_path or ./quantify 0 your_seperated_model_path output_path)";
+  std::string shell_command = "mkdir " + param_min_path;
+  system(shell_command.c_str());
+
+  for (const auto &block : program->Blocks()) {
+    for (const auto &var_desc : block->Vars()) {
+      if (var_desc->Persistable()) {
+        if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
+          continue;
+        }
+        std::string file_name = param_min_path + "/" + var_desc->Name();
+        FILE *out_file = fopen(file_name.c_str(), "wb");
+        char *origin_data = Get_binary_data(model_dir + "/" + var_desc->Name());
+        char *data = origin_data;
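+        // round-trip this var's float32 payload through uint8 quantization and dump the result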
+        LoadWithDumpForFloat32(*var_desc, &data, out_file);
+        delete[] origin_data;
+        fclose(out_file);
+      }
+    }
+  }
+}
+
+int main(int argc, char **argv) {
+  const std::string kNoteEg = "( eg: ./quantify 1 your_combined_model_path output_path or ./quantify 0 your_seperated_model_path output_path or ./quantify 3 your_combined_model_path output_path or ./quantify 2 your_seperated_model_path output_path)";
   PADDLE_MOBILE_ENFORCE(argc > 1, "wee need params.%s ", kNoteEg.c_str());
   std::string action_type = argv[1];
-  PADDLE_MOBILE_ENFORCE(argc > 1 && (action_type) == "1" || action_type == "0",
-                        "only 1 or 2 supported, current is %s %s ",
+  PADDLE_MOBILE_ENFORCE(argc > 1 && (action_type == "0" || action_type == "1" || action_type == "2" || action_type == "3"),
+                        "only 0, 1, 2 or 3 supported, current is %s %s ",
                         action_type.c_str(), kNoteEg.c_str());
@@ -251,7 +403,7 @@ int main(int argc, char **argv) {
   if (action_type == "0") {
     // for seperated
     const std::string &seperated_min_dir = output_path;
-    quantificate_seperated(base_path, seperated_min_dir);
+    quantificate_seperated_int8(base_path, seperated_min_dir);
     return 0;
   }
 
@@ -260,16 +412,26 @@ int main(int argc, char **argv) {
     const std::string &combined_min_dir = output_path;
     std::string model_path = base_path + "/model";
     std::string param_path = base_path + "/params";
-    quantificate_combined(model_path, param_path, combined_min_dir);
+    quantificate_combined_int8(model_path, param_path, combined_min_dir);
+    return 0;
+  }
+
+  if (action_type == "2") {
+    // for seperated
+    const std::string &seperated_min_dir = output_path;
+    quantificate_seperated_float32(base_path, seperated_min_dir);
+    return 0;
+  }
+
+  if (action_type == "3") {
+    // for combined
+    const std::string &combined_min_dir = output_path;
+    std::string model_path = base_path + "/model";
+    std::string param_path = base_path + "/params";
+    quantificate_combined_float32(model_path, param_path, combined_min_dir);
     return 0;
   }
   return -1;
 }
-
-
-
-
-
-- 
GitLab
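
For reference, a minimal standalone sketch of the min/max quantization round-trip that LoadWithDumpForFloat32 applies to each tensor above. quantize_roundtrip is a hypothetical helper written for illustration, not part of the patch; it assumes a non-empty input, and the zero-range guard is an addition (the patched code divides by max_value - min_value directly):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Map each float onto a uint8_t factor in [0, 255], then map back, so the
// output keeps float32 layout but holds at most 256 distinct values.
std::vector<float> quantize_roundtrip(const std::vector<float> &values) {
  float min_value = *std::min_element(values.begin(), values.end());
  float max_value = *std::max_element(values.begin(), values.end());
  float range = max_value - min_value;
  if (range == 0.0f) return values;  // constant tensor: nothing to quantize (added guard)
  std::vector<float> out;
  out.reserve(values.size());
  for (float v : values) {
    auto factor = static_cast<uint8_t>(std::round((v - min_value) / range * 255));
    out.push_back(min_value + (factor / 255.0f) * range);
  }
  return out;
}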