提交 d9487155 编写于 作者: Z zhaocai 提交者: jackzhang235

save offline model

上级 53b6e51b
...@@ -19,9 +19,11 @@ ...@@ -19,9 +19,11 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
#include "lite/kernels/mlu/bridges/tensor.h" #include "lite/kernels/mlu/bridges/tensor.h"
#include "lite/utils/env.h"
#define PRINT_HW_TIME false #define PRINT_HW_TIME false
...@@ -96,7 +98,14 @@ class Graph { ...@@ -96,7 +98,14 @@ class Graph {
std::vector<std::shared_ptr<MLUTensor>>* MutableOutputs() { std::vector<std::shared_ptr<MLUTensor>>* MutableOutputs() {
return &output_tensors_; return &output_tensors_;
} }
void GenOfflineModel(const std::string& name) {
cnmlModel_t model;
std::string filename = name + ".offline.cambricon";
CNML_CALL(cnmlCreateModel(&model, name.c_str()));
CNML_CALL(cnmlAddFusionOpToModel(model, fusion_op_, filename.c_str()));
CNML_CALL(cnmlSaveModel(model, filename.c_str()));
CNML_CALL(cnmlDestroyModel(model));
}
void FuseOp(cnmlBaseOp_t op) { CNML_CALL(cnmlFuseOp(op, fusion_op_)); } void FuseOp(cnmlBaseOp_t op) { CNML_CALL(cnmlFuseOp(op, fusion_op_)); }
void Compile(cnmlCoreVersion_t core_version, int core_number) { void Compile(cnmlCoreVersion_t core_version, int core_number) {
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
#pragma once #pragma once
#include <algorithm>
#include <map> #include <map>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "lite/api/paddle_place.h" #include "lite/api/paddle_place.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
...@@ -156,9 +158,60 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -156,9 +158,60 @@ class SubgraphEngine : public subgraph::Engine {
auto core_number = mlu_context.MLUCoreNumber(); auto core_number = mlu_context.MLUCoreNumber();
graph->Compile(core_version, core_number); graph->Compile(core_version, core_number);
shape_graph_map_[new_shape] = graph; shape_graph_map_[new_shape] = graph;
if (GetBoolFromEnv("SAVE_MLU_OFFLINE_MODEL")) {
graph->GenOfflineModel(GetOfflineModName());
}
return status; return status;
} }
// Strips a hex pointer address ("0x...") and one transformation-pass suffix
// ("/trans_io_copy", "/trans_cast" or "/trans_layout") from a tensor name so
// it can serve as a stable file-name component.
//
// Both orderings are handled explicitly: "<op>0x<addr><suffix>" and
// "<op><suffix>0x<addr>" (the original code handled the latter only via
// size_t wraparound of `found_end - found` being clamped by replace()).
// If either marker is missing the name is returned unchanged.
std::string TrimStrings(std::string origin_str) {
  std::string str = std::move(origin_str);
  const std::size_t addr_pos = str.find("0x");
  if (addr_pos == std::string::npos) return str;
  const std::vector<std::string> del_strs = {
      "/trans_io_copy", "/trans_cast", "/trans_layout"};
  for (const auto& del_str : del_strs) {
    std::size_t suffix_pos = str.find(del_str);
    if (suffix_pos == std::string::npos) continue;
    if (addr_pos < suffix_pos) {
      // "<op>0x<addr><suffix>": drop the address part before the suffix.
      str.erase(addr_pos, suffix_pos - addr_pos);
    } else {
      // "<op><suffix>0x<addr>": the address runs to the end of the name.
      str.erase(addr_pos);
    }
    // Now remove the (shifted) suffix itself; only one suffix is trimmed.
    suffix_pos = str.find(del_str);
    if (suffix_pos != std::string::npos) {
      str.erase(suffix_pos, del_str.size());
    }
    break;
  }
  return str;
}
// Builds a deterministic file name for the offline model from the sorted
// input/output tensor names and their current shapes, e.g.
// "<in>__input_shape_1_3_224_224___<out>__output_shape_1_1000__".
//
// NOTE: sorts input_names_/output_names_ in place so that repeated calls
// (and separate processes) agree on the ordering.
std::string GetOfflineModName() {
  std::sort(input_names_.begin(), input_names_.end());
  std::sort(output_names_.begin(), output_names_.end());
  const std::string delimiter = "__";
  const std::string delimiter_num = "_";
  std::string name;
  // Appends "<trimmed name>__<tag><d0>_<d1>_..._<dn>___" for one tensor.
  auto append_tensor = [&](const std::string& tensor_name, const char* tag) {
    name += TrimStrings(tensor_name) + delimiter + tag;
    auto* tensor = scope_->FindMutableTensor(tensor_name);
    for (const auto dim : tensor->dims().Vectorize()) {
      name += std::to_string(dim) + delimiter_num;
    }
    name += delimiter;
  };
  for (const auto& input_name : input_names_) {
    append_tensor(input_name, "input_shape_");
  }
  for (const auto& output_name : output_names_) {
    append_tensor(output_name, "output_shape_");
  }
  // '/' is not valid in a file name; graph var names frequently contain it.
  std::replace(name.begin(), name.end(), '/', '-');
  return name;
}
int LaunchDeviceProgram() override { int LaunchDeviceProgram() override {
// prepare input and output memory // prepare input and output memory
auto graph = shape_graph_map_[inputs_shape_]; auto graph = shape_graph_map_[inputs_shape_];
...@@ -226,7 +279,7 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -226,7 +279,7 @@ class SubgraphEngine : public subgraph::Engine {
std::map<std::vector<std::vector<int64_t>>, std::map<std::vector<std::vector<int64_t>>,
std::shared_ptr<paddle::lite::subgraph::mlu::Graph>> std::shared_ptr<paddle::lite::subgraph::mlu::Graph>>
shape_graph_map_{}; shape_graph_map_{};
}; }; // namespace mlu
template <PrecisionType Precision> template <PrecisionType Precision>
class SubgraphCompute class SubgraphCompute
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册