Commit 3168be11 authored by Yanzhan Yang, committed by GitHub

enhance quantification tool to dump float32 params. (#1731)

Parent a114444f
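Summary (editorial note, inferred from the diff below): this commit renames the existing int8 dump path (LoadWithDump → LoadWithDumpForInt8, quantificate_combined → quantificate_combined_int8, quantificate_seperated → quantificate_seperated_int8) and adds a float32 dump path. The float32 path quantizes each tensor to uint8 using per-tensor min/max scaling, immediately dequantizes back to float32, writes the dequantized values, and prints the quantization error per variable, so the output remains an ordinary float32 params file. Two new CLI modes (2: separated, 3: combined) select it. Lines prefixed with `-` below are removed by this commit; unprefixed lines are the resulting file content.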
@@ -68,7 +68,7 @@ std::shared_ptr<ProgramDesc> loadParams(const std::string &model_path) {
}
-void LoadWithDump(const paddle_mobile::framework::VarDesc &var_desc, char **dataP, FILE *out_file) {
void LoadWithDumpForInt8(const paddle_mobile::framework::VarDesc &var_desc, char **dataP, FILE *out_file) {
// 1. version
uint32_t version = *reinterpret_cast<uint32_t *>(*dataP);
@@ -182,8 +182,7 @@ void LoadWithDump(const paddle_mobile::framework::VarDesc &var_desc, char **data
}
void
-quantificate_combined(const std::string &model_path, const std::string &param_path, const std::string &param_min_path) {
quantificate_combined_int8(const std::string &model_path, const std::string &param_path, const std::string &param_min_path) {
auto program = loadParams(model_path);
char *origin_data = Get_binary_data(param_path);
char *data = origin_data;
@@ -194,17 +193,15 @@ quantificate_combined(const std::string &model_path, const std::string &param_pa
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
continue;
}
-LoadWithDump(*var_desc, &data, out_file);
LoadWithDumpForInt8(*var_desc, &data, out_file);
}
}
}
fclose(out_file);
delete[] origin_data;
}
-void quantificate_seperated(const std::string model_dir, const std::string param_min_path) {
void quantificate_seperated_int8(const std::string model_dir, const std::string param_min_path) {
auto program = loadParams(model_dir + "/__model__");
std::string shell_command = "mkdir " + param_min_path;
@@ -220,25 +217,180 @@ void quantificate_seperated(const std::string model_dir, const std::string param
FILE *out_file = fopen(file_name.c_str(), "wb");
char *origin_data = Get_binary_data(model_dir + "/" + var_desc->Name());
char *data = origin_data;
-LoadWithDump(*var_desc, &data, out_file);
LoadWithDumpForInt8(*var_desc, &data, out_file);
delete[] origin_data;
fclose(out_file);
}
}
}
}
void LoadWithDumpForFloat32(const paddle_mobile::framework::VarDesc &var_desc, char **dataP, FILE *out_file) {
// 1. version
uint32_t version = *reinterpret_cast<uint32_t *>(*dataP);
// write version
fwrite(&version, kSize32, 1, out_file);
*dataP += kSize32;
// 2 Lod information
auto *lod_level_ptr = new uint64_t();
memcpy(lod_level_ptr, *dataP, kSize64);
uint64_t lod_level = *lod_level_ptr;  // use the lod level read from the input instead of discarding it
// write lod Information
fwrite(&lod_level, kSize64, 1, out_file);
delete lod_level_ptr;
*dataP += kSize64;
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size = *reinterpret_cast<uint64_t *>(*dataP);
// write lod size
fwrite(&size, kSize64, 1, out_file);
(*dataP) += kSize64;
std::vector<size_t> tmp(size / sizeof(size_t));
for (unsigned long &k : tmp) {
k = *reinterpret_cast<size_t *>(*dataP);
(*dataP) += sizeof(size_t);
}
// write lod size vector
fwrite(tmp.data(), sizeof(size_t), tmp.size(), out_file);  // write the vector's buffer, not the vector object
}
// 3. tensor version
uint32_t tensor_version = *reinterpret_cast<uint32_t *>(*dataP);
// write tensor version
fwrite(&tensor_version, kSize32, 1, out_file);
(*dataP) += kSize32;
// 4. tensor desc
int32_t size = *reinterpret_cast<int32_t *>(*dataP);
// write tensor desc
fwrite(&size, sizeof(int32_t), 1, out_file);
(*dataP) += sizeof(int32_t);
std::unique_ptr<char[]> buf(new char[size]);
for (int m = 0; m < size; ++m) {
buf.get()[m] = (*dataP)[m];
}
fwrite(buf.get(), sizeof(char), static_cast<size_t>(size), out_file);
(*dataP) += (sizeof(char) * size);
const paddle_mobile::framework::TensorDesc &desc = var_desc.Tensor_desc();
int memory_size = 1;
for (auto l : desc.Dims()) {
memory_size *= l;
}
void *memory = nullptr;
int type_size = 0;
switch (desc.DataType()) {
case paddle_mobile::framework::VARTYPE_TYPE_FP16:
type_size = 2;
break;
case paddle_mobile::framework::VARTYPE_TYPE_FP32:
type_size = 4;
break;
case paddle_mobile::framework::VARTYPE_TYPE_FP64:
type_size = 8;
break;
case paddle_mobile::framework::VARTYPE_TYPE_INT32:
type_size = 4;
break;
case paddle_mobile::framework::VARTYPE_TYPE_INT64:
type_size = 8;
break;
case paddle_mobile::framework::VARTYPE_TYPE_BOOL:
type_size = 1;
break;
default:
break;
}
size_t tensorSize = sizeof(char) * memory_size * type_size;
memory = new char[tensorSize];
for (size_t n = 0; n < tensorSize; ++n) {
static_cast<char *>(memory)[n] = (*dataP)[n];
}
*dataP += tensorSize;
// for float 32
float min_value = std::numeric_limits<float>::max();
float max_value = std::numeric_limits<float>::lowest();  // min() is the smallest positive float, not the most negative
for (int k = 0; k < memory_size; ++k) {
min_value = std::min(min_value, static_cast<float *> (memory)[k]);
max_value = std::max(max_value, static_cast<float *> (memory)[k]);
}
float diff = 0.0;
float range = max_value - min_value;
if (range == 0) range = 1;  // guard against division by zero for constant tensors
for (int g = 0; g < memory_size; ++g) {
float value = static_cast<float *> (memory)[g];
auto factor = (uint8_t) round((value - min_value) / range * 255);
float value_quantized = min_value + (factor / 255.0) * range;
diff += std::abs(value - value_quantized);  // std::abs keeps the float overload; plain abs would truncate to int
fwrite(&value_quantized, sizeof(float), 1, out_file);
}
std::cout << "avg diff caused by quantization for var " << var_desc.Name() << " is: " << diff / memory_size << std::endl;
delete[] static_cast<char *>(memory);  // release the tensor copy
}
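// Editorial worked example (not part of the commit): with min_value = -1.0f and
// max_value = 1.0f, a weight value = 0.3f maps to
// factor = round((0.3 - (-1.0)) / 2.0 * 255) = round(165.75) = 166,
// which dequantizes to value_quantized = -1.0 + (166 / 255.0) * 2.0 ≈ 0.30196,
// contributing ≈ 0.002 to diff. With 8 bits the worst-case per-element error
// is about (max_value - min_value) / (2 * 255).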
void
quantificate_combined_float32(const std::string &model_path, const std::string &param_path, const std::string &param_min_path) {
auto program = loadParams(model_path);
char *origin_data = Get_binary_data(param_path);
char *data = origin_data;
FILE *out_file = fopen(param_min_path.c_str(), "wb");
for (const auto &block : program->Blocks()) {
for (const auto &var_desc : block->Vars()) {
if (var_desc->Persistable()) {
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
continue;
}
LoadWithDumpForFloat32(*var_desc, &data, out_file);
}
}
}
fclose(out_file);
delete[] origin_data;  // Get_binary_data allocates with new[], so use array delete
}
-int main(int argc, char **argv) {
void quantificate_seperated_float32(const std::string model_dir, const std::string param_min_path) {
auto program = loadParams(model_dir + "/__model__");
const std::string kNoteEg = "( eg: ./quantify 1 your_combined_model_path output_path or ./quantify 0 your_seperated_model_path output_path)";
std::string shell_command = "mkdir " + param_min_path;
system(shell_command.c_str());
for (const auto &block : program->Blocks()) {
for (const auto &var_desc : block->Vars()) {
if (var_desc->Persistable()) {
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
continue;
}
std::string file_name = param_min_path + "/" + var_desc->Name();
FILE *out_file = fopen(file_name.c_str(), "wb");
char *origin_data = Get_binary_data(model_dir + "/" + var_desc->Name());
char *data = origin_data;
LoadWithDumpForFloat32(*var_desc, &data, out_file);
delete[] origin_data;
fclose(out_file);
}
}
}
}
int main(int argc, char **argv) {
const std::string kNoteEg = "( eg: ./quantify 1 your_combined_model_path output_path or ./quantify 0 your_seperated_model_path output_path or ./quantify 3 your_combined_model_path output_path or ./quantify 2 your_seperated_model_path output_path)";
PADDLE_MOBILE_ENFORCE(argc > 1, "wee need params.%s ", kNoteEg.c_str());
std::string action_type = argv[1];
-PADDLE_MOBILE_ENFORCE(argc > 1 && (action_type) == "1" || action_type == "0",
-"only 1 or 2 supported, current is %s %s ",
PADDLE_MOBILE_ENFORCE(argc > 1 && (action_type == "0" || action_type == "1" || action_type == "2" || action_type == "3"),
"only 0, 1, 2 or 3 supported, current is %s %s ",
action_type.c_str(),
kNoteEg.c_str());
@@ -251,7 +403,7 @@ int main(int argc, char **argv) {
if (action_type == "0") {
// for seperated
const std::string &seperated_min_dir = output_path;
-quantificate_seperated(base_path, seperated_min_dir);
quantificate_seperated_int8(base_path, seperated_min_dir);
return 0;
}
@@ -260,16 +412,26 @@ int main(int argc, char **argv) {
const std::string &combined_min_dir = output_path;
std::string model_path = base_path + "/model";
std::string param_path = base_path + "/params";
-quantificate_combined(model_path, param_path, combined_min_dir);
quantificate_combined_int8(model_path, param_path, combined_min_dir);
return 0;
}
if (action_type == "2") {
// for seperated
const std::string &seperated_min_dir = output_path;
quantificate_seperated_float32(base_path, seperated_min_dir);
return 0;
}
if (action_type == "3") {
// for combined
const std::string &combined_min_dir = output_path;
std::string model_path = base_path + "/model";
std::string param_path = base_path + "/params";
quantificate_combined_float32(model_path, param_path, combined_min_dir);
return 0;
}
return -1;
}
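For reference, the resulting command-line usage (a sketch inferred from kNoteEg and the mode dispatch in main() above; model and output paths are placeholders):

./quantify 0 your_seperated_model_path output_path  # int8 dump, separated params (one output file per variable)
./quantify 1 your_combined_model_path output_path   # int8 dump, combined params file
./quantify 2 your_seperated_model_path output_path  # float32 dump, separated params
./quantify 3 your_combined_model_path output_path   # float32 dump, combined params file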