提交 d9487155 编写于 作者: Z zhaocai 提交者: jackzhang235

save offline model

上级 53b6e51b
......@@ -19,9 +19,11 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/tensor.h"
#include "lite/kernels/mlu/bridges/tensor.h"
#include "lite/utils/env.h"
#define PRINT_HW_TIME false
......@@ -96,7 +98,14 @@ class Graph {
// Returns a mutable pointer to the graph's list of output MLU tensors,
// letting callers register/modify outputs before compilation.
std::vector<std::shared_ptr<MLUTensor>>* MutableOutputs() {
return &output_tensors_;
}
void GenOfflineModel(const std::string& name) {
cnmlModel_t model;
std::string filename = name + ".offline.cambricon";
CNML_CALL(cnmlCreateModel(&model, name.c_str()));
CNML_CALL(cnmlAddFusionOpToModel(model, fusion_op_, filename.c_str()));
CNML_CALL(cnmlSaveModel(model, filename.c_str()));
CNML_CALL(cnmlDestroyModel(model));
}
void FuseOp(cnmlBaseOp_t op) { CNML_CALL(cnmlFuseOp(op, fusion_op_)); }
void Compile(cnmlCoreVersion_t core_version, int core_number) {
......
......@@ -14,10 +14,12 @@
#pragma once
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
......@@ -156,9 +158,60 @@ class SubgraphEngine : public subgraph::Engine {
auto core_number = mlu_context.MLUCoreNumber();
graph->Compile(core_version, core_number);
shape_graph_map_[new_shape] = graph;
if (GetBoolFromEnv("SAVE_MLU_OFFLINE_MODEL")) {
graph->GenOfflineModel(GetOfflineModName());
}
return status;
}
// Strips the pointer-address suffix (starting at "0x") together with the
// first matching pass-inserted suffix ("/trans_io_copy", "/trans_cast" or
// "/trans_layout") from an op/tensor name, e.g.
//   "fc_0x7fff/trans_cast_out" -> "fc__out".
// A name lacking either marker is returned unchanged.
//
// Fixes vs. the previous version:
//  - guards against size_t underflow when "0x" appears AFTER the delimiter
//    (old code computed `found_end - found` unsigned, yielding a huge length);
//  - removes the address and the delimiter with a single erase instead of a
//    fragile second find/replace pair that could throw on npos.
std::string TrimStrings(std::string origin_str) {
  static const std::vector<std::string> del_strs = {
      "/trans_io_copy", "/trans_cast", "/trans_layout"};
  const std::size_t addr_pos = origin_str.find("0x");
  if (addr_pos == std::string::npos) {
    return origin_str;
  }
  for (const auto& del_str : del_strs) {
    const std::size_t del_pos = origin_str.find(del_str);
    // Only trim when the address precedes the delimiter, so the erased
    // span [addr_pos, del_pos + del_str.size()) is well-formed.
    if (del_pos != std::string::npos && addr_pos <= del_pos) {
      origin_str.erase(addr_pos, del_pos + del_str.size() - addr_pos);
      break;
    }
  }
  return origin_str;
}
// Builds a deterministic offline-model name from the (sorted) input and
// output tensor names plus their shapes, e.g.
//   "<in>__input_shape_1_3_224_224___<out>__output_shape_1_1000___".
// Note: sorts input_names_/output_names_ in place so the name is stable
// regardless of registration order; '/' is rewritten to '-' so the result
// is usable as a single file-name component.
std::string GetOfflineModName() {
  static const std::string kDelim = "__";
  static const std::string kNumDelim = "_";
  std::sort(input_names_.begin(), input_names_.end());
  std::sort(output_names_.begin(), output_names_.end());
  std::string model_name;
  // Appends "<trimmed name>__<tag><d0>_<d1>_..._<dn>___" for one tensor.
  auto append_tensor = [&](const std::string& tensor_name,
                           const std::string& tag) {
    model_name += TrimStrings(tensor_name) + kDelim + tag;
    auto* tensor = scope_->FindMutableTensor(tensor_name);
    for (auto dim : tensor->dims().Vectorize()) {
      model_name += std::to_string(dim) + kNumDelim;
    }
    model_name += kDelim;
  };
  for (const auto& input_name : input_names_) {
    append_tensor(input_name, "input_shape_");
  }
  for (const auto& output_name : output_names_) {
    append_tensor(output_name, "output_shape_");
  }
  std::replace(model_name.begin(), model_name.end(), '/', '-');
  return model_name;
}
int LaunchDeviceProgram() override {
// prepare input and output memory
auto graph = shape_graph_map_[inputs_shape_];
......@@ -226,7 +279,7 @@ class SubgraphEngine : public subgraph::Engine {
std::map<std::vector<std::vector<int64_t>>,
std::shared_ptr<paddle::lite::subgraph::mlu::Graph>>
shape_graph_map_{};
};
};  // class SubgraphEngine
template <PrecisionType Precision>
class SubgraphCompute
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册