diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index 6732b968734631cf74c1e8fc7b825f3e0b89b9fe..485bd10770d6e5a29963f336dfdf6d47302ccbc0 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -26,6 +26,7 @@ USE_MIR_PASS(argument_type_display_pass);
 USE_MIR_PASS(runtime_context_assign_pass);
 USE_MIR_PASS(graph_visualize_pass);
+USE_MIR_PASS(remove_tf_redundant_ops_pass);
 USE_MIR_PASS(lite_conv_bn_fuse_pass);
 USE_MIR_PASS(lite_fc_fuse_pass);
 USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
index b8234b18922f454c41e295209da13de024184adc..2540bb56d4082570c984e8eea009b5575825fec9 100644
--- a/lite/core/mir/CMakeLists.txt
+++ b/lite/core/mir/CMakeLists.txt
@@ -29,6 +29,7 @@ lite_cc_library(mir_passes
        elimination/identity_scale_eliminate_pass.cc
        elimination/identity_dropout_eliminate_pass.cc
        elimination/elementwise_mul_constant_eliminate_pass.cc
+       elimination/remove_tf_redundant_ops_pass.cc
        static_kernel_pick_pass.cc
        variable_place_inference_pass.cc
        type_target_cast_pass.cc
diff --git a/lite/core/mir/elimination/remove_tf_redundant_ops_pass.cc b/lite/core/mir/elimination/remove_tf_redundant_ops_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f4226820d0437db8cad0cfdac92be15359bb90bd
--- /dev/null
+++ b/lite/core/mir/elimination/remove_tf_redundant_ops_pass.cc
@@ -0,0 +1,245 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/elimination/remove_tf_redundant_ops_pass.h"
+#include <set>
+#include "lite/core/mir/graph_visualize_pass.h"
+#include "lite/core/mir/pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/pattern_matcher.h"
+#include "lite/model_parser/cpp/var_desc.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void RemoveTFRedundantOpsPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  RemoveSqueeze2Reshape2Pattern(graph);
+  RemoveReshape2Pattern(graph);
+}
+
+void RemoveTFRedundantOpsPass::RemoveReshape2Pattern(
+    const std::unique_ptr<SSAGraph>& graph) {
+  bool found = false;
+  Node* softmax_node{nullptr};
+  Node* reshape2_node{nullptr};
+  std::string reshape2_out_arg_name;
+  Node* fetch_node{nullptr};
+  std::string fetch_in_arg_name;
+  DDim softmax_out_dims;
+  DDim reshape2_out_dims;
+
+  for (auto& op_node : graph->StmtTopologicalOrder()) {
+    if (op_node->AsStmt().picked_kernel().op_type() == "softmax") {
+      softmax_node = op_node;
+    } else if (op_node->AsStmt().picked_kernel().op_type() == "reshape2") {
+      reshape2_node = op_node;
+    } else if (op_node->AsStmt().picked_kernel().op_type() == "fetch") {
+      fetch_node = op_node;
+      fetch_in_arg_name = fetch_node->inlinks.front()->AsArg().name;
+    }
+  }
+
+  if (softmax_node == nullptr || reshape2_node == nullptr ||
+      fetch_node == nullptr) {
+    return;
+  }
+
+  // Get output tensor dims of the softmax and reshape2 nodes
+  auto* scope = softmax_node->AsStmt().op()->scope();
+  auto softmax_out_arg_name = softmax_node->outlinks.front()->AsArg().name;
+  auto softmax_out_tensor =
+      scope->FindVar(softmax_out_arg_name)->Get<lite::Tensor>();
+  softmax_out_dims = softmax_out_tensor.dims();
+
+  for (auto out_node : reshape2_node->outlinks) {
+    if (out_node->IsArg() && out_node->outlinks.size() != 0) {
+      reshape2_out_arg_name = reshape2_node->outlinks.front()->AsArg().name;
+      auto reshape2_out_tensor =
+          scope->FindVar(reshape2_out_arg_name)->Get<lite::Tensor>();
+      reshape2_out_dims = reshape2_out_tensor.dims();
+    }
+  }
+
+  VLOG(3) << "reshape2_out_dims:" << reshape2_out_dims;
+  VLOG(3) << "softmax_out_dims:" << softmax_out_dims;
+
+  if (softmax_out_dims == reshape2_out_dims &&
+      softmax_node->outlinks.front() == reshape2_node->inlinks.front() &&
+      reshape2_out_arg_name == fetch_in_arg_name) {
+    found = true;
+  }
+  VLOG(3) << "found:" << found;
+
+  if (found) {
+    // Link the softmax out_arg directly to the fetch op
+    IR_NODE_LINK_TO(softmax_node->outlinks.front(), fetch_node);
+
+    // Collect nodes that are safe to remove
+    std::set<const Node*> nodes_to_remove;
+    auto remove_inst_node_and_out_args_node = [&](Node* n) {
+      nodes_to_remove.insert(n);
+      for (auto& out : n->outlinks) {
+        nodes_to_remove.insert(out);
+      }
+    };
+
+    remove_inst_node_and_out_args_node(reshape2_node);
+    GraphSafeRemoveNodes(graph.get(), nodes_to_remove);
+    auto fetch_op_desc = fetch_node->AsStmt().mutable_op_info();
+    fetch_op_desc->SetInput("X",
+                            {softmax_node->outlinks.front()->AsArg().name});
+  }
+  VLOG(5) << "\n" << Visualize(graph.get());
+}
+
+void RemoveTFRedundantOpsPass::RemoveSqueeze2Reshape2Pattern(
+    const std::unique_ptr<SSAGraph>& graph) {
+  VLOG(5) << Visualize(graph.get());
+  bool found = false;
+
+  // find out_arg->squeeze2
+  // find out_arg_dims of out_arg
+  Node* out_arg_node{nullptr};
+  DDim out_arg_dims;
+  Node* squeeze2_node{nullptr};
+
+  // find squeeze2->reshape2
+  // find output dims of squeeze2 and reshape2 nodes
+  DDim squeeze2_out_dims;
+  Node* reshape2_node{nullptr};
+  Node* reshape2_out_node{nullptr};
+  DDim reshape2_out_dims;
+
+  // find next inst node of reshape2
+  Node* next_inst_node_of_reshape2_out{nullptr};
+
+  for (auto& node : graph->StmtTopologicalOrder()) {
+    if (node->AsStmt().picked_kernel().op_type() != "squeeze2") continue;
+    auto* scope = node->AsStmt().op()->scope();
+
+    // find inlinks of squeeze2: out_arg_node
+    squeeze2_node = node;
+    auto squeeze2_inlinks = squeeze2_node->inlinks;
+    VLOG(5) << "squeeze2_inlinks.size():" << squeeze2_inlinks.size();
+    for (auto& in_link : squeeze2_inlinks) {
+      if (in_link->IsArg() && squeeze2_inlinks.size() == 1) {
+        out_arg_node = in_link;
+        auto* var = scope->FindVar(out_arg_node->AsArg().name);
+        out_arg_dims = var->Get<lite::Tensor>().dims();
+        VLOG(5) << "arg name:" << out_arg_node->AsArg().name
+                << " dims:" << out_arg_dims;
+      } else {
+        // found multi-input links
+        continue;
+      }
+    }
+
+    // find squeeze2->reshape2 pattern
+    // and output dims of squeeze2, reshape2 nodes
+    auto squeeze2_outlinks = squeeze2_node->outlinks;
+    for (auto& squeeze2_out_link : squeeze2_outlinks) {
+      if (squeeze2_out_link->IsArg() &&
+          squeeze2_out_link->outlinks.size() != 0) {
+        auto* squeeze2_out_var =
+            scope->FindVar(squeeze2_out_link->AsArg().name);
+        squeeze2_out_dims = squeeze2_out_var->Get<lite::Tensor>().dims();
+
+        VLOG(5) << "squeeze2_out_arg.name:" << squeeze2_out_link->AsArg().name
+                << " squeeze2_out_dims:" << squeeze2_out_dims
+                << " squeeze2_out_link->outlinks.size():"
+                << squeeze2_out_link->outlinks.size();
+
+        for (auto& out2_link : squeeze2_out_link->outlinks) {
+          if (out2_link->IsStmt() &&
+              out2_link->AsStmt().picked_kernel().op_type() == "reshape2") {
+            reshape2_node = out2_link;
+            for (auto& reshape2_out_link : reshape2_node->outlinks) {
+              if (reshape2_out_link->IsArg() &&
+                  reshape2_out_link->outlinks.size() != 0) {
+                reshape2_out_node = reshape2_out_link;
+                auto* reshape2_out_var =
+                    scope->FindVar(reshape2_out_link->AsArg().name);
+                reshape2_out_dims =
+                    reshape2_out_var->Get<lite::Tensor>().dims();
+
+                VLOG(5) << "reshape2_out_node:" << reshape2_out_node
+                        << " reshape2_out_name:"
+                        << reshape2_out_link->AsArg().name
+                        << " reshape2_out_dims:" << reshape2_out_dims;
+              }
+            }
+          }
+        }
+      }
+    }
+
+    // find next inst node of reshape2
+    if (reshape2_out_node != nullptr &&
+        reshape2_out_node->outlinks.size() == 1 &&
+        reshape2_out_node->outlinks.front()->IsStmt()) {
+      VLOG(5) << "reshape2_out_node->outlinks.size():"
+              << reshape2_out_node->outlinks.size()
+              << " reshape2_out_node->IsStmt():" << reshape2_out_node->IsStmt();
+      VLOG(5) << "reshape2_out_node->AsArg().name:"
+              << reshape2_out_node->AsArg().name;
+      next_inst_node_of_reshape2_out = reshape2_out_node->outlinks.front();
+      found = true;
+      VLOG(5)
+          << "next_inst_node_of_reshape2_out->picked_kernel().op_type():"
+          << next_inst_node_of_reshape2_out->AsStmt().picked_kernel().op_type();
+      break;
+    }
+
+    VLOG(5) << "==============================";
+    VLOG(5) << "out_arg_dims:" << out_arg_dims;
+    VLOG(5) << "squeeze2_out_dims:" << squeeze2_out_dims;
+    VLOG(5) << "reshape2_out_dims:" << reshape2_out_dims;
+    VLOG(5) << "==============================";
+  }
+
+  // replace pattern
+  if (found && out_arg_dims[1] == squeeze2_out_dims[1] &&
+      out_arg_dims[1] == reshape2_out_dims[1] && out_arg_dims[1] == 1001 &&
+      out_arg_dims[2] == out_arg_dims[3] && out_arg_dims[2] == 1 &&
+      next_inst_node_of_reshape2_out->AsStmt().picked_kernel().op_type() ==
+          "softmax") {
+    // Link out_arg directly to the next inst (softmax) node
+    IR_NODE_LINK_TO(out_arg_node, next_inst_node_of_reshape2_out);
+
+    // Collect nodes that are safe to remove
+    std::set<const Node*> nodes_to_remove;
+    auto remove_inst_node_and_out_args_node = [&](Node* n) {
+      nodes_to_remove.insert(n);
+      for (auto& out : n->outlinks) {
+        nodes_to_remove.insert(out);
+      }
+    };
+    remove_inst_node_and_out_args_node(squeeze2_node);
+    remove_inst_node_and_out_args_node(reshape2_node);
+    GraphSafeRemoveNodes(graph.get(), nodes_to_remove);
+    auto next_inst_op_desc =
+        next_inst_node_of_reshape2_out->AsStmt().mutable_op_info();
+    next_inst_op_desc->SetInput("X", {out_arg_node->AsArg().name});
+    VLOG(5) << Visualize(graph.get());
+  }
+  VLOG(5) << "replace pattern finished";
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(remove_tf_redundant_ops_pass,
+                  paddle::lite::mir::RemoveTFRedundantOpsPass)
+    .BindTargets({TARGET(kOpenCL), TARGET(kARM)});
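Once registered via REGISTER_MIR_PASS above, the pass is found by name at optimization time. A minimal sketch (not part of this patch) of how a registered MIR pass is looked up and applied, mirroring what Optimizer::RunPasses does later in this patch; `graph` is assumed to be an already-built std::unique_ptr<mir::SSAGraph>:

    mir::Pass* pass =
        mir::PassManager::Global().LookUp("remove_tf_redundant_ops_pass");
    if (pass != nullptr) {
      // Runs RemoveSqueeze2Reshape2Pattern first, then RemoveReshape2Pattern.
      pass->Apply(graph);
    }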
diff --git a/lite/core/mir/elimination/remove_tf_redundant_ops_pass.h b/lite/core/mir/elimination/remove_tf_redundant_ops_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..652a8fb4a7f67e173527725e3bbecfadcde96798
--- /dev/null
+++ b/lite/core/mir/elimination/remove_tf_redundant_ops_pass.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "lite/core/mir/pass.h"
+#include "lite/core/tensor.h"
+#include "lite/core/types.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+/*
+ * mir::RemoveTFRedundantOpsPass removes the squeeze2->reshape2 pattern
+ * and the trailing reshape2 op for TensorFlow MobileNetV1/V2 models.
+ */
+class RemoveTFRedundantOpsPass : public mir::StmtPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+  void RemoveReshape2Pattern(const std::unique_ptr<SSAGraph>& graph);
+  void RemoveSqueeze2Reshape2Pattern(const std::unique_ptr<SSAGraph>& graph);
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
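For reference, the two subgraphs this pass rewrites, sketched from the matching logic in the .cc above (the 1x1001 shape is the MobileNet classifier output that the pattern is specialized for):

    // RemoveSqueeze2Reshape2Pattern:
    //   out_arg(1x1001x1x1) -> squeeze2 -> arg(1x1001) -> reshape2
    //       -> arg(1x1001) -> softmax
    //   becomes:  out_arg -> softmax
    //
    // RemoveReshape2Pattern:
    //   softmax -> out_arg -> reshape2 -> arg(same dims as out_arg) -> fetch
    //   becomes:  softmax -> out_arg -> fetch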
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index 05f801facdf9557da1e872d69fcde0bf3b321d2e..579f7690d7b73bb400d68cbcaf138b32bb23a6ce 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -108,7 +108,8 @@ class Optimizer {
          "bm_subgraph_pass",
          "apu_subgraph_pass",
          "rknpu_subgraph_pass",
-         "static_kernel_pick_pass",  // pick original kernel from graph
+         "static_kernel_pick_pass",   // pick original kernel from graph
+         "remove_tf_redundant_ops_pass",
          "variable_place_inference_pass",  // inference arg/var's
          // info(target/precision/layout/device)
          // using kernel info
@@ -174,6 +175,55 @@ class Optimizer {
 
   const lite::Scope* exec_scope() const { return exec_scope_; }
 
+  // Copy the shape (dims) info from each var desc onto the corresponding
+  // scope var, so that developers can write passes that use the input/output
+  // tensor dims of ops.
+  //
+  // Example: given a node `Node* softmax_node`, a pass can read the dims of
+  // its output tensor like this:
+  //
+  //   auto* scope = softmax_node->AsStmt().op()->scope();
+  //   auto softmax_out_arg_name =
+  //       softmax_node->outlinks.front()->AsArg().name;
+  //   auto softmax_out_tensor =
+  //       scope->FindVar(softmax_out_arg_name)->Get<lite::Tensor>();
+  //   softmax_out_dims = softmax_out_tensor.dims();
+  void SetVarDescShapeToScopeVar() {
+    auto dims_to_str_func = [](std::vector<int64_t> shape) -> std::string {
+      std::string str_res;
+      for (size_t i = 0; i < shape.size(); ++i) {
+        str_res += std::to_string(shape[i]);
+        if (i != shape.size() - 1) {
+          str_res += "x";
+        }
+      }
+      return str_res;
+    };
+
+    auto* program_desc = program_->program_desc();
+    VLOG(5) << "program_desc->BlocksSize():" << program_desc->BlocksSize();
+    auto blocks_desc = program_desc->GetBlocks();
+    for (size_t bidx = 0; bidx < blocks_desc.size(); ++bidx) {
+      auto block_desc = blocks_desc[bidx];
+      auto vars_desc = block_desc.GetVars();
+      for (size_t vidx = 0; vidx < vars_desc.size(); ++vidx) {
+        auto var_desc = vars_desc[vidx];
+        VLOG(5) << var_desc.Name() << " "
+                << dims_to_str_func(var_desc.GetShape());
+        if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue;
+        auto* var = program_->exec_scope()->FindVar(var_desc.Name());
+        auto tensor = var->GetMutable<lite::Tensor>();
+        if (tensor->dims().size() == 0 && var_desc.GetShape().size() != 0) {
+          VLOG(5) << "var_desc.Name():" << var_desc.Name()
+                  << " shape:" << dims_to_str_func(var_desc.GetShape());
+          tensor->Resize(var_desc.GetShape());
+        }
+        VLOG(5) << "var_desc.Name():" << var_desc.Name()
+                << " shape:" << dims_to_str_func(var_desc.GetShape())
+                << " tensor:" << tensor->dims();
+      }
+    }
+  }
+
   // Generate a new program based on the mir graph.
   std::unique_ptr<RuntimeProgram> GenRuntimeProgram() {
     auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
         "generate_program_pass");
@@ -214,6 +264,7 @@ class Optimizer {
   // Specify the passes and run them.
   void RunPasses(const std::vector<std::string>& passes) {
+    SetVarDescShapeToScopeVar();
     for (auto& x : passes) {
       LOG(INFO) << "== Running pass: " << x;
       mir::Pass* pass = mir::PassManager::Global().LookUp(x);
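The shape-seeding step above can be pictured for a single variable as follows; a minimal sketch in which `exec_scope`, the var name "softmax_out", and the shape {1, 1001} are hypothetical stand-ins for the real loop over every var desc of every block:

    // Before any kernel runs, scope tensors have empty dims. Copying the
    // var desc's static shape in lets shape-comparing passes such as
    // remove_tf_redundant_ops_pass reason about dims at optimization time.
    std::vector<int64_t> desc_shape = {1, 1001};     // shape from the var desc
    auto* var = exec_scope->FindVar("softmax_out");  // hypothetical var name
    auto* tensor = var->GetMutable<lite::Tensor>();
    if (tensor->dims().size() == 0 && desc_shape.size() != 0) {
      tensor->Resize(desc_shape);  // same call the loop above makes
    }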
diff --git a/lite/core/program.h b/lite/core/program.h
index 6fe65f158b8d547e7a741e329a192d2661a60060..c1cc6113d747a134a3ffbdc7291273db2b275ef8 100644
--- a/lite/core/program.h
+++ b/lite/core/program.h
@@ -71,6 +71,8 @@ struct Program {
   lite::Scope* exec_scope() { return exec_scope_; }
   lite::Scope* scope() { return scope_.get(); }
 
+  cpp::ProgramDesc* program_desc() { return &desc_; }
+
   const std::map<std::string, PrecisionType>& var_data_type() const {
     return var_data_type_;
   }
diff --git a/lite/model_parser/cpp/block_desc.h b/lite/model_parser/cpp/block_desc.h
index b6f473b88b84bff71650dd4ecf4d1dc803351212..c4bb756b57005cc820bae2840a8494815f4c5ecb 100644
--- a/lite/model_parser/cpp/block_desc.h
+++ b/lite/model_parser/cpp/block_desc.h
@@ -45,6 +45,8 @@ class BlockDesc : public BlockDescAPI {
   template <typename T>
   T* GetVar(int32_t idx);
 
+  std::vector<VarDesc>& GetVars() { return vars_; }
+
   template <typename T>
   T* AddVar();
diff --git a/lite/model_parser/cpp/program_desc.h b/lite/model_parser/cpp/program_desc.h
index 786dad134adf8d5ac4b03ba43b254359dfc2cdb2..f935380cae08ceb5c1a5e712f553063f50255298 100644
--- a/lite/model_parser/cpp/program_desc.h
+++ b/lite/model_parser/cpp/program_desc.h
@@ -36,6 +36,8 @@ class ProgramDesc : public ProgramDescAPI {
   template <typename T>
   T* GetBlock(int32_t idx);
 
+  std::vector<BlockDesc>& GetBlocks() { return blocks_; }
+
   template <typename T>
   T* AddBlock();
diff --git a/lite/tests/math/deformable_conv_compute_test.cc b/lite/tests/math/deformable_conv_compute_test.cc
index e97203123d1db0752189a9965c922b048cd6bd38..d7a06c6a104ac3ac6db5d79aced6183e8bdf5963 100644
--- a/lite/tests/math/deformable_conv_compute_test.cc
+++ b/lite/tests/math/deformable_conv_compute_test.cc
@@ -34,7 +34,7 @@ DEFINE_int32(power_mode,
 DEFINE_int32(threads, 1, "threads num");
 DEFINE_int32(warmup, 0, "warmup times");
 DEFINE_int32(repeats, 1, "repeats times");
-DEFINE_bool(basic_test, true, "do all tests");
+DEFINE_bool(basic_test, false, "do all tests");
 DEFINE_bool(check_result, true, "check the result");
 DEFINE_int32(batch, 1, "batch size");
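Since `basic_test` now defaults to false, the exhaustive parameter sweep in this test only runs when the flag is passed explicitly, e.g. (the binary name below is assumed from the test target and may differ in your build):

    ./deformable_conv_compute_test --basic_test=true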