提交 5f4c9179 编写于 作者: Z zhaocai 提交者: jackzhang235

modify strigs

上级 d9487155
...@@ -100,7 +100,7 @@ class Graph { ...@@ -100,7 +100,7 @@ class Graph {
} }
void GenOfflineModel(const std::string& name) { void GenOfflineModel(const std::string& name) {
cnmlModel_t model; cnmlModel_t model;
std::string filename = name + ".offline.cambricon"; const auto& filename = name + ".offline.cambricon";
CNML_CALL(cnmlCreateModel(&model, name.c_str())); CNML_CALL(cnmlCreateModel(&model, name.c_str()));
CNML_CALL(cnmlAddFusionOpToModel(model, fusion_op_, filename.c_str())); CNML_CALL(cnmlAddFusionOpToModel(model, fusion_op_, filename.c_str()));
CNML_CALL(cnmlSaveModel(model, filename.c_str())); CNML_CALL(cnmlSaveModel(model, filename.c_str()));
......
...@@ -164,13 +164,13 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -164,13 +164,13 @@ class SubgraphEngine : public subgraph::Engine {
return status; return status;
} }
std::string TrimStrings(std::string origin_str) { std::string TrimStrings(const std::string& origin_str) {
std::string str = origin_str; std::string str = origin_str;
std::size_t found = str.find("0x"); std::size_t found = str.find("0x");
std::size_t found_end = 0; std::size_t found_end = 0;
std::vector<std::string> del_strs = { const std::vector<std::string> del_strs = {
"/trans_io_copy", "/trans_cast", "/trans_layout"}; "/trans_io_copy", "/trans_cast", "/trans_layout"};
for (auto iterm : del_strs) { for (const auto& iterm : del_strs) {
found_end = str.find(iterm); found_end = str.find(iterm);
// trim point address and one of the del_strs // trim point address and one of the del_strs
if (found != std::string::npos && found_end != std::string::npos) { if (found != std::string::npos && found_end != std::string::npos) {
...@@ -186,24 +186,26 @@ class SubgraphEngine : public subgraph::Engine { ...@@ -186,24 +186,26 @@ class SubgraphEngine : public subgraph::Engine {
std::string GetOfflineModName() { std::string GetOfflineModName() {
sort(input_names_.begin(), input_names_.end()); sort(input_names_.begin(), input_names_.end());
sort(output_names_.begin(), output_names_.end()); sort(output_names_.begin(), output_names_.end());
const auto& delimiter = "__";
const auto& delimiter_num = "_";
const auto& input_shape_str = "input_shape_";
const auto& output_shape_str = "output_shape_";
std::string name = ""; std::string name = "";
std::string delimiter = "__";
std::string delimiter_num = "_";
std::string tmp = ""; std::string tmp = "";
for (auto input_name : input_names_) { for (const auto& input_name : input_names_) {
tmp = input_name; tmp = input_name;
name += TrimStrings(tmp) + delimiter + "input_shape_"; name += TrimStrings(tmp) + delimiter + input_shape_str;
auto input_tensor = scope_->FindMutableTensor(input_name); auto input_tensor = scope_->FindMutableTensor(input_name);
for (auto iterm : input_tensor->dims().Vectorize()) { for (const auto& iterm : input_tensor->dims().Vectorize()) {
name += std::to_string(iterm) + delimiter_num; name += std::to_string(iterm) + delimiter_num;
} }
name += delimiter; name += delimiter;
} }
for (auto output_name : output_names_) { for (const auto& output_name : output_names_) {
tmp = output_name; tmp = output_name;
name += TrimStrings(tmp) + delimiter + "output_shape_"; name += TrimStrings(tmp) + delimiter + output_shape_str;
auto output_tensor = scope_->FindMutableTensor(output_name); auto output_tensor = scope_->FindMutableTensor(output_name);
for (auto iterm : output_tensor->dims().Vectorize()) { for (const auto& iterm : output_tensor->dims().Vectorize()) {
name += std::to_string(iterm) + delimiter_num; name += std::to_string(iterm) + delimiter_num;
} }
name += delimiter; name += delimiter;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册