Unverified commit 29771f27, authored by Yuan Shuai, committed by GitHub

[LITE][PASS] Remove reshape2 / squeeze2 for tf_mobilenetv1/v2 (#3773)

* [LITE][PASS] Add pass for removing useless reshape2 / squeeze2. test=develop
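
In brief (a sketch of the intent, node names as in the pass implementation below), the pass rewrites two fetch-side patterns:

out_arg -> squeeze2 -> reshape2 -> softmax ==> out_arg -> softmax
softmax -> reshape2 -> fetch ==> softmax -> fetch

Both rewrites fire only under the dimension guards checked in the pass: a 1001-channel head with 1x1 spatial dims for the first pattern, and identical softmax/reshape2 output dims for the second.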
Parent 7f22d9f0
...@@ -26,6 +26,7 @@ USE_MIR_PASS(argument_type_display_pass);
USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(graph_visualize_pass);
USE_MIR_PASS(remove_tf_redundant_ops_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
...
...@@ -29,6 +29,7 @@ lite_cc_library(mir_passes
elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
elimination/remove_tf_redundant_ops_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_cast_pass.cc
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/elimination/remove_tf_redundant_ops_pass.h"
#include <set>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
#include "lite/model_parser/cpp/var_desc.h"
namespace paddle {
namespace lite {
namespace mir {
void RemoveTFRedundantOpsPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
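// On the TF MobileNet v1/v2 graphs this pass targets, the tail is
// ... -> squeeze2 -> reshape2 -> softmax -> reshape2 -> fetch, so the
// squeeze2+reshape2 pair in front of softmax is folded first, then the
// trailing reshape2 before fetch is dropped.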
RemoveSqueeze2Reshape2Pattern(graph);
RemoveReshape2Pattern(graph);
}
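// Matches softmax -> reshape2 -> fetch where softmax and reshape2 produce
// tensors with identical dims, removes the reshape2 node together with its
// output args, and rewires fetch to read the softmax output directly.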
void RemoveTFRedundantOpsPass::RemoveReshape2Pattern(
const std::unique_ptr<SSAGraph>& graph) {
bool found = false;
Node* softmax_node{nullptr};
Node* reshape2_node{nullptr};
std::string reshape2_out_arg_name;
Node* fetch_node{nullptr};
std::string fetch_in_arg_name;
DDim softmax_out_dims;
DDim reshape2_out_dims;
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (op_node->AsStmt().picked_kernel().op_type() == "softmax") {
softmax_node = op_node;
} else if (op_node->AsStmt().picked_kernel().op_type() == "reshape2") {
reshape2_node = op_node;
} else if (op_node->AsStmt().picked_kernel().op_type() == "fetch") {
fetch_node = op_node;
fetch_in_arg_name = fetch_node->inlinks.front()->AsArg().name;
}
}
if (softmax_node == nullptr || reshape2_node == nullptr ||
fetch_node == nullptr) {
return;
}
// Get output tensor dims of softmax and reshape2
auto* scope = softmax_node->AsStmt().op()->scope();
auto softmax_out_arg_name = softmax_node->outlinks.front()->AsArg().name;
auto softmax_out_tensor =
scope->FindVar(softmax_out_arg_name)->Get<lite::Tensor>();
softmax_out_dims = softmax_out_tensor.dims();
for (auto out_node : reshape2_node->outlinks) {
if (out_node->IsArg() && out_node->outlinks.size() != 0) {
reshape2_out_arg_name = out_node->AsArg().name;
auto reshape2_out_tensor =
scope->FindVar(reshape2_out_arg_name)->Get<lite::Tensor>();
reshape2_out_dims = reshape2_out_tensor.dims();
}
}
VLOG(3) << "reshape2_out_dims:" << reshape2_out_dims;
VLOG(3) << "softmax_out_dims:" << softmax_out_dims;
VLOG(3) << "found:" << found;
if (softmax_out_dims == reshape2_out_dims &&
softmax_node->outlinks.front() == reshape2_node->inlinks.front() &&
reshape2_out_arg_name == fetch_in_arg_name) {
found = true;
}
if (found) {
// link out_arg to op
IR_NODE_LINK_TO(softmax_node->outlinks.front(), fetch_node);
// collect nodes to safely remove
std::set<const Node*> nodes_to_remove;
auto remove_inst_node_and_out_args_node = [&](Node* n) {
nodes_to_remove.insert(n);
for (auto& out : n->outlinks) {
nodes_to_remove.insert(out);
}
};
remove_inst_node_and_out_args_node(reshape2_node);
GraphSafeRemoveNodes(graph.get(), nodes_to_remove);
auto fetch_op_desc = fetch_node->AsStmt().mutable_op_info();
fetch_op_desc->SetInput("X",
{softmax_node->outlinks.front()->AsArg().name});
}
VLOG(5) << "\n" << Visualize(graph.get());
}
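// Matches out_arg -> squeeze2 -> reshape2 -> softmax where out_arg has
// dims Nx1001x1x1, removes the squeeze2 and reshape2 nodes together with
// their output args, and feeds out_arg into softmax directly.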
void RemoveTFRedundantOpsPass::RemoveSqueeze2Reshape2Pattern(
const std::unique_ptr<SSAGraph>& graph) {
VLOG(5) << Visualize(graph.get());
bool found = false;
// find out_arg->squeeze2
// find out_arg_dims of out_arg
Node* out_arg_node{nullptr};
DDim out_arg_dims;
Node* squeeze2_node{nullptr};
// find squeeze2->reshape2
// find output dims of squeeze2 and reshape2 nodes
DDim squeeze2_out_dims;
Node* reshape2_node{nullptr};
Node* reshape2_out_node{nullptr};
DDim reshape2_out_dims;
// find next inst node of reshape2
Node* next_inst_node_of_reshape2_out{nullptr};
for (auto& node : graph->StmtTopologicalOrder()) {
if (node->AsStmt().picked_kernel().op_type() != "squeeze2") continue;
auto* scope = node->AsStmt().op()->scope();
// find inlinks of squeeze2: out_arg_node
squeeze2_node = node;
auto squeeze2_inlinks = squeeze2_node->inlinks;
VLOG(5) << "squeeze2_inlinks.size():" << squeeze2_inlinks.size();
for (auto& in_link : squeeze2_inlinks) {
if (in_link->IsArg() && squeeze2_inlinks.size() == 1) {
out_arg_node = in_link;
auto* var = scope->FindVar(out_arg_node->AsArg().name);
out_arg_dims = var->Get<lite::Tensor>().dims();
VLOG(5) << "arg name:" << out_arg_node->AsArg().name
<< " dims:" << out_arg_dims;
} else {
// found multi-input links
continue;
}
}
// find squeeze2->reshape2 pattern
// and output dims of squeeze2, reshape2 nodes
auto squeeze2_outlinks = squeeze2_node->outlinks;
for (auto& squeeze2_out_link : squeeze2_outlinks) {
if (squeeze2_out_link->IsArg() &&
squeeze2_out_link->outlinks.size() != 0) {
auto* squeeze2_out_var =
scope->FindVar(squeeze2_out_link->AsArg().name);
squeeze2_out_dims = squeeze2_out_var->Get<lite::Tensor>().dims();
VLOG(5) << "squeeze2_out_arg.name:" << squeeze2_out_link->AsArg().name
<< " squeeze2_out_dims:" << squeeze2_out_dims
<< " squeeze2_out_link->outlinks.size():"
<< squeeze2_out_link->outlinks.size();
for (auto& out2_link : squeeze2_out_link->outlinks) {
if (out2_link->IsStmt() &&
out2_link->AsStmt().picked_kernel().op_type() == "reshape2") {
reshape2_node = out2_link;
for (auto& reshape2_out_link : reshape2_node->outlinks) {
if (reshape2_out_link->IsArg() &&
reshape2_out_link->outlinks.size() != 0) {
reshape2_out_node = reshape2_out_link;
auto* reshape2_out_var =
scope->FindVar(reshape2_out_link->AsArg().name);
reshape2_out_dims =
reshape2_out_var->Get<lite::Tensor>().dims();
VLOG(5) << "reshape2_out_node:" << reshape2_out_node
<< " reshape2_out_name:"
<< reshape2_out_link->AsArg().name
<< " reshape2_out_dims:" << reshape2_out_dims;
}
}
}
}
}
}
// find next inst node of reshape2
VLOG(5) << "reshape2_out_node->outlinks.size():"
<< reshape2_out_node->outlinks.size()
<< " reshape2_out_node->IsStmt():" << reshape2_out_node->IsStmt();
VLOG(5) << "reshape2_out_node->AsArg().name:"
<< reshape2_out_node->AsArg().name;
if (reshape2_out_node != nullptr &&
reshape2_out_node->outlinks.size() == 1 &&
reshape2_out_node->outlinks.front()->IsStmt()) {
next_inst_node_of_reshape2_out = reshape2_out_node->outlinks.front();
found = true;
break;
VLOG(5)
<< "next_inst_node_of_reshape2_out->picked_kernel().op_type():"
<< next_inst_node_of_reshape2_out->AsStmt().picked_kernel().op_type();
}
VLOG(5) << "==============================";
VLOG(5) << "out_arg_dims:" << out_arg_dims;
VLOG(5) << "squeeze2_out_dims:" << squeeze2_out_dims;
VLOG(5) << "reshape2_out_dims:" << reshape2_out_dims;
VLOG(5) << "==============================";
}
// replace pattern
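// Guard: only fire for the TF MobileNet classification head, i.e. a
// 1001-channel tensor (presumably the 1000 ImageNet classes plus the
// background class of the TF-slim models) with 1x1 spatial dims, feeding
// a softmax op.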
if (found && out_arg_dims[1] == squeeze2_out_dims[1] &&
out_arg_dims[1] == reshape2_out_dims[1] && out_arg_dims[1] == 1001 &&
out_arg_dims[2] == out_arg_dims[3] && out_arg_dims[2] == 1 &&
next_inst_node_of_reshape2_out->AsStmt().picked_kernel().op_type() ==
"softmax") {
// link out_arg to op
IR_NODE_LINK_TO(out_arg_node, next_inst_node_of_reshape2_out);
// collect nodes to safely remove
std::set<const Node*> nodes_to_remove;
auto remove_inst_node_and_out_args_node = [&](Node* n) {
nodes_to_remove.insert(n);
for (auto& out : n->outlinks) {
nodes_to_remove.insert(out);
}
};
remove_inst_node_and_out_args_node(squeeze2_node);
remove_inst_node_and_out_args_node(reshape2_node);
GraphSafeRemoveNodes(graph.get(), nodes_to_remove);
auto next_inst_op_desc =
next_inst_node_of_reshape2_out->AsStmt().mutable_op_info();
next_inst_op_desc->SetInput("X", {out_arg_node->AsArg().name});
VLOG(5) << Visualize(graph.get());
}
VLOG(5) << "replace pattern fininshed";
}
} // namespace mir
} // namespace lite
} // namespace paddle
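// Bound to the OpenCL and ARM targets, so the optimizer applies this pass
// only to graphs built for those targets.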
REGISTER_MIR_PASS(remove_tf_redundant_ops_pass,
paddle::lite::mir::RemoveTFRedundantOpsPass)
.BindTargets({TARGET(kOpenCL), TARGET(kARM)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/tensor.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace mir {
/*
* mir::RemoveTFRedundantOpsPass removes the squeeze2->reshape2 pattern
* and the last reshape2 op for TensorFlow MobileNet v1/v2.
*/
class RemoveTFRedundantOpsPass : public mir::StmtPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
void RemoveReshape2Pattern(const std::unique_ptr<SSAGraph>& graph);
void RemoveSqueeze2Reshape2Pattern(const std::unique_ptr<SSAGraph>& graph);
};
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -109,6 +109,7 @@ class Optimizer {
"apu_subgraph_pass",
"rknpu_subgraph_pass",
"static_kernel_pick_pass",  // pick original kernel from graph
"remove_tf_redundant_ops_pass",
"variable_place_inference_pass",  // inference arg/var's
// info(target/precision/layout/device)
// using kernel info
...@@ -174,6 +175,55 @@ class Optimizer {
const lite::Scope* exec_scope() const { return exec_scope_; }
// Set the shapes (dims) recorded in var descs onto the scope vars,
// so that a developer can write passes using the input/output tensor
// dims of an op.
//
// Example: If you have node `Node* softmax_node`,
// you can get dims of output tensor in passes:
//
// auto* scope = softmax_node->AsStmt().op()->scope();
// auto softmax_out_arg_name =
// softmax_node->outlinks.front()->AsArg().name;
// auto softmax_out_tensor =
// scope->FindVar(softmax_out_arg_name)->Get<lite::Tensor>();
// softmax_out_dims = softmax_out_tensor.dims();
void SetVarDescShapeToScopeVar() {
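// Helper: joins a shape with 'x' for logging, e.g. an (illustrative)
// {1, 1001, 1, 1} becomes "1x1001x1x1".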
auto dims_to_str_func = [](std::vector<int64_t> shape) -> std::string {
std::string str_res;
for (size_t i = 0; i < shape.size(); ++i) {
str_res += std::to_string(shape[i]);
if (i != shape.size() - 1) {
str_res += "x";
}
}
return str_res;
};
auto* program_desc = program_->program_desc();
VLOG(5) << "program_desc->BlocksSize():" << program_desc->BlocksSize();
auto blocks_desc = program_desc->GetBlocks();
for (size_t bidx = 0; bidx < blocks_desc.size(); ++bidx) {
auto block_desc = blocks_desc[bidx];
auto vars_desc = block_desc.GetVars();
for (size_t vidx = 0; vidx < vars_desc.size(); ++vidx) {
auto var_desc = vars_desc[vidx];
VLOG(5) << var_desc.Name() << " "
<< dims_to_str_func(var_desc.GetShape());
if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
auto* var = program_->exec_scope()->FindVar(var_desc.Name());
auto tensor = var->GetMutable<lite::Tensor>();
if (tensor->dims().size() == 0 && var_desc.GetShape().size() != 0) {
VLOG(5) << "var_desc.Name():" << var_desc.Name()
<< " shape:" << dims_to_str_func(var_desc.GetShape());
tensor->Resize(var_desc.GetShape());
}
VLOG(5) << "var_desc.Name():" << var_desc.Name()
<< " shape:" << dims_to_str_func(var_desc.GetShape())
<< " tensor:" << tensor->dims();
}
}
}
// Generate a new program based on the mir graph.
std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
...@@ -214,6 +264,7 @@ class Optimizer {
// Specify the passes and run them.
void RunPasses(const std::vector<std::string>& passes) {
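// Copy the shapes recorded in the var descs into the scope tensors up
// front, so that shape-aware passes such as remove_tf_redundant_ops_pass
// can query dims.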
SetVarDescShapeToScopeVar();
for (auto& x : passes) {
LOG(INFO) << "== Running pass: " << x;
mir::Pass* pass = mir::PassManager::Global().LookUp(x);
...
...@@ -71,6 +71,8 @@ struct Program {
lite::Scope* exec_scope() { return exec_scope_; }
lite::Scope* scope() { return scope_.get(); }
cpp::ProgramDesc* program_desc() { return &desc_; }
const std::map<std::string, PrecisionType>& var_data_type() const {
return var_data_type_;
}
...
...@@ -45,6 +45,8 @@ class BlockDesc : public BlockDescAPI {
template <typename T>
T* GetVar(int32_t idx);
std::vector<VarDesc>& GetVars() { return vars_; }
template <typename T>
T* AddVar();
...
...@@ -36,6 +36,8 @@ class ProgramDesc : public ProgramDescAPI {
template <typename T>
T* GetBlock(int32_t idx);
std::vector<BlockDesc>& GetBlocks() { return blocks_; }
template <typename T>
T* AddBlock();
...
...@@ -34,7 +34,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
-DEFINE_bool(basic_test, true, "do all tests");
+DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(batch, 1, "batch size");
...