提交 41602396 编写于 作者: S sangoly 提交者: GitHub

add conv_elementwise_add_relu_fuse_pass test=develop (#18079)

上级 234fb8f4
...@@ -6,6 +6,7 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) ...@@ -6,6 +6,7 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory(fusion) add_subdirectory(fusion)
cc_library(mir_passes cc_library(mir_passes
SRCS fc_fuse_pass.cc SRCS fc_fuse_pass.cc
conv_elementwise_add_relu_fuse_pass.cc
static_kernel_pick_pass.cc static_kernel_pick_pass.cc
variable_place_inference_pass.cc variable_place_inference_pass.cc
type_target_transform_pass.cc type_target_transform_pass.cc
...@@ -15,7 +16,7 @@ cc_library(mir_passes ...@@ -15,7 +16,7 @@ cc_library(mir_passes
argument_type_display_pass.cc argument_type_display_pass.cc
demo_pass.cc demo_pass.cc
runtime_context_assign_pass.cc runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite mir_fusers) DEPS mir_pass types_lite context_lite ${mir_fusers})
# for mobile, unnecessary to compile the following testings. # for mobile, unnecessary to compile the following testings.
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
...@@ -74,3 +75,8 @@ lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc ...@@ -74,3 +75,8 @@ lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz")
add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz) add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
lite_cc_test(test_lite_conv_elementwise_add_relu_fuse
SRCS conv_elementwise_add_relu_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels})
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvElementwiseAddReLUFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvElementwiseAddReLUFuser fuser;
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass,
paddle::lite::mir::ConvElementwiseAddReLUFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvElementwiseAddReLUFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* conv2d_1 = main_block->AppendOp();
auto* conv2d_2 = main_block->AppendOp();
auto* add_1 = main_block->AppendOp();
auto* add_2 = main_block->AppendOp();
auto* relu_1 = main_block->AppendOp();
auto* relu_2 = main_block->AppendOp();
main_block->Var("input_1");
main_block->Var("input_2");
main_block->Var("filter_1");
main_block->Var("filter_2");
main_block->Var("conv2d_1_out");
main_block->Var("conv2d_2_out");
main_block->Var("bias_1");
main_block->Var("add_1_out");
main_block->Var("add_2_out");
main_block->Var("relu_1_out");
main_block->Var("out");
scope->Var("input_1")->GetMutable<lite::Tensor>();
scope->Var("input_2")->GetMutable<lite::Tensor>();
scope->Var("filter_1")->GetMutable<lite::Tensor>();
scope->Var("filter_2")->GetMutable<lite::Tensor>();
scope->Var("conv2d_1_out")->GetMutable<lite::Tensor>();
scope->Var("conv2d_2_out")->GetMutable<lite::Tensor>();
scope->Var("bias_1")->GetMutable<lite::Tensor>();
scope->Var("add_1_out")->GetMutable<lite::Tensor>();
scope->Var("add_2_out")->GetMutable<lite::Tensor>();
scope->Var("relu_1_out")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>();
conv2d_1->SetType("conv2d");
conv2d_1->SetInput("Input", {"input_1"});
conv2d_1->SetInput("Filter", {"filter_1"});
conv2d_1->SetOutput("Output", {"conv2d_1_out"});
conv2d_1->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_1->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_1->SetAttr("groups", 1);
conv2d_1->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_1->SetAttr("fuse_relu", false);
add_1->SetType("elementwise_add");
add_1->SetInput("X", {"conv2d_1_out"});
add_1->SetInput("Y", {"bias_1"});
add_1->SetOutput("Out", {"add_1_out"});
add_1->SetAttr("axis", 1);
relu_1->SetType("relu");
relu_1->SetInput("Input", {"add_1_out"});
relu_1->SetOutput("Out", {"relu_1_out"});
conv2d_2->SetType("conv2d");
conv2d_2->SetInput("Input", {"input_2"});
conv2d_2->SetInput("Filter", {"filter_2"});
conv2d_2->SetOutput("Output", {"conv2d_2_out"});
conv2d_2->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_2->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_2->SetAttr("groups", 1);
conv2d_2->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_2->SetAttr("fuse_relu", false);
add_2->SetType("elementwise_add");
add_2->SetInput("X", {"conv2d_2_out"});
add_2->SetInput("Y", {"relu_1_out"});
add_2->SetOutput("Out", {"add_2_out"});
add_2->SetAttr("axis", 1);
relu_2->SetType("relu");
relu_2->SetInput("Input", {"add_2_out"});
relu_2->SetOutput("Out", {"out"});
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(conv_elementwise_add_relu_fuse_pass, graph_test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
ASSERT_EQ(graph->nodes().size(),
11UL /*vars*/ + 6UL /*ops*/ + 2UL /*feed op + fetch op*/);
Visualize(graph.get());
}
TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
const int num_nodes = graph->nodes().size();
auto* fuser = new ConvElementwiseAddReLUFusePass;
fuser->Apply(graph);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ +
1UL * 2 /* fused fc node*/);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(elementwise_add);
USE_LITE_OP(conv2d);
USE_LITE_OP(depthwise_conv2d);
USE_LITE_OP(relu);
cc_library(mir_fusers cc_library(fuse_fc
SRCS fc_fuser.cc SRCS fc_fuser.cc
DEPS pattern_matcher_high_api) DEPS pattern_matcher_high_api)
cc_library(conv_elementwise_add_relu
SRCS conv_elementwise_add_relu_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
conv_elementwise_add_relu
CACHE INTERNAL "fusers")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ConvElementwiseAddReLUFuser::BuildPattern() {
// create input nodes.
auto* input = VarNode("input");
auto* filter = VarNode("filter");
auto* bias = VarNode("bias");
// create op nodes
auto* conv2d = OpNode("conv2d", "conv2d");
auto* add = OpNode("add", "elementwise_add");
auto* relu = OpNode("relu", "relu");
// create intermediate nodes
auto* conv2d_out = VarNode("conv2d_out");
auto* add_out = VarNode("add_out");
// create output node
auto* out = VarNode("output");
// create topology.
std::vector<PMNode*> conv2d_inputs{filter, input};
std::vector<PMNode*> add_inputs{conv2d_out, bias};
conv2d_inputs >> *conv2d >> *conv2d_out;
add_inputs >> *add >> *add_out;
*add_out >> *relu >> *out;
// Some op specialities.
conv2d_out->AsIntermediate();
add_out->AsIntermediate();
conv2d->AsIntermediate();
add->AsIntermediate();
relu->AsIntermediate();
}
void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create("conv2d");
auto conv_old = matched.at("conv2d")->stmt()->op;
auto* scope = conv_old->scope();
auto& valid_places = conv_old->valid_places();
conv_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
IR_NODE_LINK_TO(matched.at("input"), new_op_node);
IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
auto* desc = matched.at("conv2d")->stmt()->op_info();
cpp::OpDesc op_desc;
op_desc.SetType("conv2d");
op_desc.SetInput("Input", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
// Other inputs. See operators/conv_op.h
std::vector<std::string> input_arg_names = desc->InputArgumentNames();
for (auto name : input_arg_names) LOG(INFO) << name;
if (std::find(input_arg_names.begin(), input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
}
// Only consider strides, padding, groups, dilations, fuse_relu for now
op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
op_desc.SetAttr("fuse_relu", true);
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ConvElementwiseAddReLUFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -23,6 +23,7 @@ namespace mir {} // namespace mir ...@@ -23,6 +23,7 @@ namespace mir {} // namespace mir
USE_MIR_PASS(demo); USE_MIR_PASS(demo);
USE_MIR_PASS(lite_fc_fuse_pass); USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass);
USE_MIR_PASS(static_kernel_pick_pass); USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass); USE_MIR_PASS(variable_place_inference_pass);
USE_MIR_PASS(type_target_transform_pass); USE_MIR_PASS(type_target_transform_pass);
......
...@@ -64,17 +64,31 @@ class ConvOpLite : public OpLite { ...@@ -64,17 +64,31 @@ class ConvOpLite : public OpLite {
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override { bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
auto X = op_desc.Input("Input").front(); auto X = op_desc.Input("Input").front();
auto Filter = op_desc.Input("Filter").front(); auto Filter = op_desc.Input("Filter").front();
auto Bias = op_desc.Input("Bias").front();
// auto ResidualData = op_desc.Input("ResidualData");
auto Out = op_desc.Output("Output").front(); auto Out = op_desc.Output("Output").front();
param_.x = scope->FindVar(X)->GetMutable<lite::Tensor>(); param_.x = scope->FindVar(X)->GetMutable<lite::Tensor>();
param_.filter = scope->FindVar(Filter)->GetMutable<lite::Tensor>(); param_.filter = scope->FindVar(Filter)->GetMutable<lite::Tensor>();
param_.bias = scope->FindVar(Bias)->GetMutable<lite::Tensor>();
// param_.residualData =
// scope->FindVar(ResidualData)->GetMutable<lite::Tensor>();
param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>(); param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) {
auto bias_var = scope->FindVar(op_desc.Input("Bias").front());
if (bias_var != nullptr) {
param_.bias =
const_cast<lite::Tensor*>(&(bias_var->Get<lite::Tensor>()));
}
}
if (std::find(input_arg_names.begin(), input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
auto residual_data_var =
scope->FindVar(op_desc.Input("ResidualData").front());
if (residual_data_var != nullptr) {
param_.residualData = const_cast<lite::Tensor*>(
&(residual_data_var->Get<lite::Tensor>()));
}
}
param_.strides = op_desc.GetAttr<std::vector<int>>("strides"); param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings"); param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups"); param_.groups = op_desc.GetAttr<int>("groups");
......
...@@ -37,7 +37,6 @@ bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -37,7 +37,6 @@ bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>(); scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
CHECK(param_.input); CHECK(param_.input);
CHECK(param_.output); CHECK(param_.output);
kernel_->SetParam(param_);
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册