Commit de2a72a4, authored by sangoly

add elementwise_add_activation fuse pass & op

Parent 09b35192
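This commit generalizes the existing conv + elementwise_add + relu fuser so that the activation type is a constructor parameter, and adds a standalone elementwise_add + activation fuse pass together with a new fused op, fusion_elementwise_add_activation. A minimal sketch of how the two passes in this diff drive their fusers (the wrapper function is hypothetical, shown only to collect both call sites; only "relu" is accepted at this point):

#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"

// Hypothetical helper gathering the fusers added/renamed by this commit.
void ApplyFusions(paddle::lite::mir::SSAGraph* graph) {
  namespace fusion = paddle::lite::mir::fusion;
  // conv2d (or depthwise_conv2d) + elementwise_add + relu
  //   -> conv2d with the fuse_relu attribute set.
  fusion::ConvElementwiseAddActivationFuser conv_fuser("conv2d", "relu");
  conv_fuser(graph);
  // elementwise_add + relu -> fusion_elementwise_add_activation.
  fusion::ElementwiseAddActivationFuser add_fuser("relu");
  add_fuser(graph);
}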
@@ -7,7 +7,8 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
 add_subdirectory(fusion)
 cc_library(mir_passes
   SRCS fc_fuse_pass.cc
-       conv_elementwise_add_relu_fuse_pass.cc
+       conv_elementwise_add_activation_fuse_pass.cc
+       elementwise_add_activation_fuse_pass.cc
        conv_bn_fuse_pass.cc
        static_kernel_pick_pass.cc
        variable_place_inference_pass.cc
@@ -82,7 +83,11 @@ lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz")
 add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
-lite_cc_test(test_lite_conv_elementwise_add_relu_fuse
-  SRCS conv_elementwise_add_relu_fuse_pass_test.cc
+lite_cc_test(test_lite_conv_elementwise_add_activation_fuse
+  SRCS conv_elementwise_add_activation_fuse_pass_test.cc
+  DEPS cxx_api_lite mir_passes
+  ${ops_lite} ${host_kernels} ${x86_kernels})
+lite_cc_test(test_lite_elementwise_add_activation_fuse
+  SRCS elementwise_add_activation_fuse_pass_test.cc
   DEPS cxx_api_lite mir_passes
   ${ops_lite} ${host_kernels} ${x86_kernels})
@@ -12,22 +12,23 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
+#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
 #include <memory>
 #include <vector>
-#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
+#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
 #include "paddle/fluid/lite/core/mir/pass_registry.h"

 namespace paddle {
 namespace lite {
 namespace mir {

-void ConvElementwiseAddReLUFusePass::Apply(
+void ConvElementwiseAddActivationFusePass::Apply(
     const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ConvElementwiseAddReLUFuser fuser("conv2d");
+  fusion::ConvElementwiseAddActivationFuser fuser("conv2d", "relu");
   fuser(graph.get());
-  fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d");
+  fusion::ConvElementwiseAddActivationFuser depthwise_fuser("depthwise_conv2d",
+                                                            "relu");
   depthwise_fuser(graph.get());
 }
@@ -35,5 +36,5 @@ void ConvElementwiseAddReLUFusePass::Apply(
 } // namespace lite
 } // namespace paddle

-REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass,
-                  paddle::lite::mir::ConvElementwiseAddReLUFusePass);
+REGISTER_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass,
+                  paddle::lite::mir::ConvElementwiseAddActivationFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvElementwiseAddActivationFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
+#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
 #include <gflags/gflags.h>
 #include <gtest/gtest.h>
 #include <vector>
@@ -135,11 +135,11 @@ TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
   auto graph = BuildGraph(&program_desc, scope, places);
   Visualize(graph.get());
   const int num_nodes = graph->nodes().size();
-  auto* fuser = new ConvElementwiseAddReLUFusePass;
+  auto* fuser = new ConvElementwiseAddActivationFusePass;
   fuser->Apply(graph);
   Visualize(graph.get());
-  ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ +
-                                       1UL * 2 /* fused fc node*/);
+  ASSERT_EQ(graph->nodes().size(),
+            num_nodes - 5UL * 2 /*nodes removed */ + 1UL * 2 /* fused nodes*/);
 }

 } // namespace fusion
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ElementwiseAddActivationFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::ElementwiseAddActivationFuser fuser("relu");
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
paddle::lite::mir::ElementwiseAddActivationFusePass);
@@ -22,7 +22,7 @@ namespace paddle {
 namespace lite {
 namespace mir {

-class ConvElementwiseAddReLUFusePass : public ProgramPass {
+class ElementwiseAddActivationFusePass : public ProgramPass {
  public:
   void Apply(const std::unique_ptr<SSAGraph>& graph) override;
 };
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* add_1 = main_block->AppendOp();
auto* add_2 = main_block->AppendOp();
auto* relu_1 = main_block->AppendOp();
auto* relu_2 = main_block->AppendOp();
main_block->Var("x_1");
main_block->Var("y_1");
main_block->Var("add_out_1");
main_block->Var("relu_out_1");
main_block->Var("y_2");
main_block->Var("add_out_2");
main_block->Var("out");
scope->Var("x_1")->GetMutable<lite::Tensor>();
scope->Var("y_1")->GetMutable<lite::Tensor>();
scope->Var("add_out_1")->GetMutable<lite::Tensor>();
scope->Var("relu_out_1")->GetMutable<lite::Tensor>();
scope->Var("y_2")->GetMutable<lite::Tensor>();
scope->Var("add_out_2")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>();
add_1->SetType("elementwise_add");
add_1->SetInput("X", {"x_1"});
add_1->SetInput("Y", {"y_1"});
add_1->SetOutput("Out", {"add_out_1"});
add_1->SetAttr("axis", 1);
relu_1->SetType("relu");
relu_1->SetInput("X", {"add_out_1"});
relu_1->SetOutput("Out", {"relu_out_1"});
add_2->SetType("elementwise_add");
add_2->SetInput("X", {"relu_out_1"});
add_2->SetInput("Y", {"y_2"});
add_2->SetOutput("Out", {"add_out_2"});
add_2->SetAttr("axis", 1);
relu_2->SetType("relu");
relu_2->SetInput("X", {"add_out_2"});
relu_2->SetOutput("Out", {"out"});
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
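
// For reference, the graph assembled by BuildGraph above contains two
// back-to-back fusible chains (a comment sketch of the topology):
//
//   x_1, y_1        --> add_1 --> add_out_1 --> relu_1 --> relu_out_1
//   relu_out_1, y_2 --> add_2 --> add_out_2 --> relu_2 --> out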
TEST(elementwise_add_activation_fuse_pass, graph_test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
ASSERT_EQ(graph->nodes().size(),
7UL /*vars*/ + 4UL /*ops*/ + 1UL /* SSAGraph tmp node*/);
}
TEST(elementwise_add_activation_fuse_pass, fuse_test_op) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
const int num_nodes = graph->nodes().size();
auto* fuser = new ElementwiseAddActivationFusePass;
fuser->Apply(graph);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(),
num_nodes - 3UL * 2 /*nodes removed */ + 1UL * 2 /* fused nodes*/);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(elementwise_add);
USE_LITE_OP(fusion_elementwise_add_activation);
USE_LITE_OP(relu);
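The ASSERT_EQ in fuse_test_op above follows from simple node accounting. A standalone sketch of the arithmetic, assuming the two-chain graph built by BuildGraph:

#include <cassert>

int main() {
  // BuildGraph yields 7 vars + 4 ops + 1 SSAGraph tmp node = 12 nodes.
  const int num_nodes = 12;
  // Each matched elementwise_add + relu chain removes three nodes (the add
  // op, the relu op, and the intermediate add_out var) and inserts a single
  // fusion_elementwise_add_activation op node; the graph holds two chains.
  const int patterns = 2;
  const int removed = 3 * patterns;
  const int inserted = 1 * patterns;
  assert(num_nodes - removed + inserted == 8);  // mirrors the ASSERT_EQ above
  return 0;
}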
 cc_library(fuse_fc
     SRCS fc_fuser.cc
     DEPS pattern_matcher_high_api)
-cc_library(fuse_conv_elementwise_add_relu
-    SRCS conv_elementwise_add_relu_fuser.cc
+cc_library(fuse_conv_elementwise_add_activation
+    SRCS conv_elementwise_add_activation_fuser.cc
     DEPS pattern_matcher_high_api)
 cc_library(fuse_conv_bn
     SRCS conv_bn_fuser.cc
     DEPS pattern_matcher_high_api)
+cc_library(fuse_elementwise_add_activation
+    SRCS elementwise_add_activation_fuser.cc
+    DEPS pattern_matcher_high_api)

 set(mir_fusers
     fuse_fc
-    fuse_conv_elementwise_add_relu
+    fuse_conv_elementwise_add_activation
     fuse_conv_bn
+    fuse_elementwise_add_activation
     CACHE INTERNAL "fusers")

 if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
...
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
+#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
 #include <memory>
 #include <vector>
@@ -21,7 +21,7 @@ namespace lite {
 namespace mir {
 namespace fusion {

-void ConvElementwiseAddReLUFuser::BuildPattern() {
+void ConvElementwiseAddActivationFuser::BuildPattern() {
   // create input nodes.
   auto* input =
       VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
@@ -36,7 +36,8 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
   auto* add = OpNode("add", "elementwise_add")
                   ->assert_is_op("elementwise_add")
                   ->AsIntermediate();
-  auto* relu = OpNode("relu", "relu")->assert_is_op("relu")->AsIntermediate();
+  auto* act =
+      OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();

   // create intermediate nodes
   auto* conv2d_out = VarNode("conv2d_out")
@@ -45,22 +46,23 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
                          ->AsIntermediate();
   auto* add_out = VarNode("add_out")
                       ->assert_is_op_output("elementwise_add", "Out")
-                      ->assert_is_op_input("relu", "X")
+                      ->assert_is_op_input(act_type_, "X")
                       ->AsIntermediate();

   // create output node
-  auto* out = VarNode("output")->assert_is_op_output("relu", "Out")->AsOutput();
+  auto* out =
+      VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();

   // create topology.
   std::vector<PMNode*> conv2d_inputs{filter, input};
   std::vector<PMNode*> add_inputs{conv2d_out, bias};
   conv2d_inputs >> *conv2d >> *conv2d_out;
   add_inputs >> *add >> *add_out;
-  *add_out >> *relu >> *out;
+  *add_out >> *act >> *out;
 }

-void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
-                                                const key2nodes_t& matched) {
+void ConvElementwiseAddActivationFuser::InsertNewNode(
+    SSAGraph* graph, const key2nodes_t& matched) {
   auto op_desc = GenOpDesc(matched);
   auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
   auto conv_old = matched.at("conv2d")->stmt()->op;
@@ -76,7 +78,8 @@ void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
   IR_NODE_LINK_TO(new_op_node, matched.at("output"));
 }

-cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
+cpp::OpDesc ConvElementwiseAddActivationFuser::GenOpDesc(
+    const key2nodes_t& matched) {
   auto* desc = matched.at("conv2d")->stmt()->op_info();

   cpp::OpDesc op_desc;
@@ -98,6 +101,7 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
   op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
   op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
   op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
+  // TODO(sangoly): support other activation types
   op_desc.SetAttr("fuse_relu", true);
   return op_desc;
 }
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ConvElementwiseAddActivationFuser : public FuseBase {
public:
explicit ConvElementwiseAddActivationFuser(const std::string& conv_type,
const std::string& act_type) {
CHECK(act_type == "relu") << "Only relu activation be supported now";
conv_type_ = conv_type;
act_type_ = act_type;
}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string conv_type_;
std::string act_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ElementwiseAddActivationFuser::BuildPattern() {
// create input nodes.
auto* x = VarNode("x")->assert_is_op_input("elementwise_add", "X")->AsInput();
auto* y = VarNode("y")->assert_is_op_input("elementwise_add", "Y")->AsInput();
// create op nodes
auto* add = OpNode("add", "elementwise_add")
->assert_is_op("elementwise_add")
->AsIntermediate();
auto* act =
OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
// create intermediate nodes
auto* add_out = VarNode("add_out")
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input(act_type_, "X")
->AsIntermediate();
// create output node
auto* out =
VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
// create topology.
std::vector<PMNode*> add_inputs{x, y};
add_inputs >> *add >> *add_out;
*add_out >> *act >> *out;
}
void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto op =
LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
auto old_op = matched.at("add")->stmt()->op;
auto* scope = old_op->scope();
auto& valid_places = old_op->valid_places();
op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("y"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc ElementwiseAddActivationFuser::GenOpDesc(
const key2nodes_t& matched) {
auto* desc = matched.at("add")->stmt()->op_info();
cpp::OpDesc op_desc;
op_desc.SetType("fusion_elementwise_add_activation");
op_desc.SetInput("X", {matched.at("x")->arg()->name});
op_desc.SetInput("Y", {matched.at("y")->arg()->name});
op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
op_desc.SetAttr("axis", desc->GetAttr<int>("axis"));
op_desc.SetAttr("act_type", act_type_);
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
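
For a matched add/relu pair from the pass test above, GenOpDesc would emit an op desc along these lines (a sketch; the variable names are illustrative, borrowed from the test graph):

cpp::OpDesc op_desc;
op_desc.SetType("fusion_elementwise_add_activation");
op_desc.SetInput("X", {"x_1"});
op_desc.SetInput("Y", {"y_1"});
op_desc.SetOutput("Out", {"relu_out_1"});
op_desc.SetAttr("axis", 1);  // copied from the matched elementwise_add
op_desc.SetAttr("act_type", std::string("relu"));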
@@ -23,16 +23,16 @@ namespace lite {
 namespace mir {
 namespace fusion {

-class ConvElementwiseAddReLUFuser : public FuseBase {
+class ElementwiseAddActivationFuser : public FuseBase {
  public:
-  explicit ConvElementwiseAddReLUFuser(const std::string& conv_type)
-      : conv_type_(conv_type) {}
+  explicit ElementwiseAddActivationFuser(const std::string& act_type)
+      : act_type_(act_type) {}

   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

  private:
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
-  std::string conv_type_;
+  std::string act_type_;
 };

 } // namespace fusion
...
@@ -34,4 +34,5 @@ USE_MIR_PASS(runtime_context_assign_pass);
 USE_MIR_PASS(lite_conv_bn_fuse_pass);
 USE_MIR_PASS(graph_visualze);
 USE_MIR_PASS(lite_fc_fuse_pass);
-USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass);
+USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass);
+USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
@@ -48,9 +48,12 @@ class Optimizer {
     if (passes.empty()) {
       RunPasses(std::vector<std::string>{{
           "lite_conv_bn_fuse_pass",                          //
-          "lite_conv_elementwise_add_act_fuse_pass",         //
+          "lite_conv_elementwise_add_activation_fuse_pass",  //
           "lite_fc_fuse_pass",                               //
+#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+          "lite_elementwise_add_activation_fuse_pass",       //
+#endif
 #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
           "static_kernel_pick_pass",        //
           "variable_place_inference_pass",  //
...
@@ -14,6 +14,7 @@ cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
 cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
 cc_library(activation_ops_lite SRCS activation_ops.cc DEPS ${op_DEPS})
 cc_library(elementwise_ops_lite SRCS elementwise_ops.cc DEPS ${op_DEPS})
+cc_library(fusion_elementwise_activation_ops_lite SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops_lite ${op_DEPS})
 cc_library(mean_op_lite SRCS mean_op.cc DEPS ${op_DEPS})
 cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
 #cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS})
@@ -36,6 +37,7 @@ set(ops_lite
     fetch_op_lite
     io_copy_op_lite
     elementwise_ops_lite
+    fusion_elementwise_activation_ops_lite
     mean_op_lite
     fill_constant_op_lite
     activation_ops_lite
@@ -56,3 +58,6 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
 lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
 lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite)
 lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
+lite_cc_test(test_fusion_elementwise_activation_ops_lite
+  SRCS fusion_elementwise_activation_ops_test.cc
+  DEPS fusion_elementwise_activation_ops_lite memory_lite)
@@ -12,92 +12,67 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/operators/elementwise_ops.h"
 #include "paddle/fluid/lite/core/op_registry.h"

 namespace paddle {
 namespace lite {
 namespace operators {

-class ElementwiseOp : public OpLite {
- public:
-  explicit ElementwiseOp(const std::string& type) : OpLite(type) {}
-
-  bool CheckShape() const override {
-    CHECK_OR_FALSE(param_.X);
-    CHECK_OR_FALSE(param_.Y);
-    CHECK_OR_FALSE(param_.Out);
-    return true;
-  }
-
-  bool InferShape() const override {
-    CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size());
-    param_.Out->Resize(param_.X->dims());
-    return true;
-  }
-
-  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
-    auto X_name = opdesc.Input("X").front();
-    auto Y_name = opdesc.Input("Y").front();
-    auto Out_name = opdesc.Output("Out").front();
-
-    param_.X = GetVar<lite::Tensor>(scope, X_name);
-    param_.Y = GetVar<lite::Tensor>(scope, Y_name);
-    param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
-    param_.axis = opdesc.GetAttr<int>("axis");
-    return true;
-  }
-
-  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
-
-  std::string DebugString() const override { return "elementwise_op"; }
-
- private:
-  mutable operators::ElementwiseParam param_;
-};
+bool ElementwiseOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.X);
+  CHECK_OR_FALSE(param_.Y);
+  CHECK_OR_FALSE(param_.Out);
+  return true;
+}
+
+bool ElementwiseOp::InferShape() const {
+  CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size());
+  param_.Out->Resize(param_.X->dims());
+  return true;
+}
+
+bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
+  auto X_name = opdesc.Input("X").front();
+  auto Y_name = opdesc.Input("Y").front();
+  auto Out_name = opdesc.Output("Out").front();
+
+  param_.X = GetVar<lite::Tensor>(scope, X_name);
+  param_.Y = GetVar<lite::Tensor>(scope, Y_name);
+  param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
+  param_.axis = opdesc.GetAttr<int>("axis");
+  return true;
+}

 #ifdef LITE_WITH_X86
-class ElementwiseGradExplicitOp : public OpLite {
- public:
-  explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) {}
-
-  bool CheckShape() const override {
-    CHECK_OR_FALSE(param_.Y);
-    CHECK_OR_FALSE(param_.X_grad);
-    CHECK_OR_FALSE(param_.Y_grad);
-    CHECK_OR_FALSE(param_.Out_grad);
-    return true;
-  }
-
-  bool InferShape() const override {
-    param_.X_grad->Resize(param_.Out_grad->dims());
-    param_.Y_grad->Resize(param_.Y->dims());
-    return true;
-  }
-
-  bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
-    CHECK_EQ(opdesc.InputArgumentNames().size(), 1UL);
-    auto Out_name = opdesc.Input(framework::GradVarName("Out")).front();
-    auto X_name = opdesc.Output(framework::GradVarName("X")).front();
-    auto Y_name = opdesc.Output(framework::GradVarName("Y")).front();
-
-    param_.Out_grad = GetVar<lite::Tensor>(scope, Out_name);
-    param_.X_grad = GetMutableVar<lite::Tensor>(scope, X_name);
-    param_.Y_grad = GetMutableVar<Tensor>(scope, Y_name);
-    param_.axis = opdesc.GetAttr<int>("axis");
-    return true;
-  }
-
-  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
-
-  std::string DebugString() const override {
-    return "elementwise_grad_explicit_op";
-  }
-
- private:
-  mutable operators::ElementwiseGradParam param_;
-};
+bool ElementwiseGradExplicitOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.Y);
+  CHECK_OR_FALSE(param_.X_grad);
+  CHECK_OR_FALSE(param_.Y_grad);
+  CHECK_OR_FALSE(param_.Out_grad);
+  return true;
+}
+
+bool ElementwiseGradExplicitOp::InferShape() const {
+  param_.X_grad->Resize(param_.Out_grad->dims());
+  param_.Y_grad->Resize(param_.Y->dims());
+  return true;
+}
+
+bool ElementwiseGradExplicitOp::AttachImpl(const cpp::OpDesc& opdesc,
+                                           lite::Scope* scope) {
+  CHECK_EQ(opdesc.InputArgumentNames().size(), 1UL);
+  auto Out_name = opdesc.Input(framework::GradVarName("Out")).front();
+  auto X_name = opdesc.Output(framework::GradVarName("X")).front();
+  auto Y_name = opdesc.Output(framework::GradVarName("Y")).front();
+
+  param_.Out_grad = GetVar<lite::Tensor>(scope, Out_name);
+  param_.X_grad = GetMutableVar<lite::Tensor>(scope, X_name);
+  param_.Y_grad = GetMutableVar<Tensor>(scope, Y_name);
+  param_.axis = opdesc.GetAttr<int>("axis");
+
+  return true;
+}
 #endif

 } // namespace operators
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class ElementwiseOp : public OpLite {
public:
explicit ElementwiseOp(const std::string& op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "elementwise_op"; }
private:
mutable operators::ElementwiseParam param_;
};
#ifdef LITE_WITH_X86
class ElementwiseGradExplicitOp : public OpLite {
public:
explicit ElementwiseGradExplicitOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "elementwise_grad_explicit_op";
}
private:
mutable operators::ElementwiseGradParam param_;
};
#endif
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h"
#include <string>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
ElementwiseOp::AttachImpl(opdesc, scope);
param_.act_type = opdesc.GetAttr<std::string>("act_type");
// TODO(sangoly): support more activation types.
CHECK(param_.act_type == "relu") << "Only relu activation be supported now";
return true;
}
#ifdef LITE_WITH_X86
bool FusionElementwiseActivationGradExplicitOp::AttachImpl(
const cpp::OpDesc& opdesc, lite::Scope* scope) {
ElementwiseGradExplicitOp::AttachImpl(opdesc, scope);
param_.act_type = opdesc.GetAttr<std::string>("act_type");
// TODO(sangoly): support more activation types.
CHECK(param_.act_type == "relu") << "Only relu activation be supported now";
return true;
}
#endif
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(fusion_elementwise_sub_activation,
paddle::lite::operators::FusionElementwiseActivationOp);
#ifdef LITE_WITH_X86
REGISTER_LITE_OP(
fusion_elementwise_sub_activation_grad,
paddle::lite::operators::FusionElementwiseActivationGradExplicitOp);
#endif
REGISTER_LITE_OP(fusion_elementwise_add_activation,
paddle::lite::operators::FusionElementwiseActivationOp);
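
Note that the file above registers the same op class for both fusion_elementwise_sub_activation and fusion_elementwise_add_activation. Once registered, the fused op can be constructed through the registry exactly as the fuser's InsertNewNode does (a short sketch; op_desc and scope stand for the desc produced by GenOpDesc and the matched program's scope):

auto op = paddle::lite::LiteOpRegistry::Global().Create(
    "fusion_elementwise_add_activation");
op->Attach(op_desc, scope);  // reads X, Y, Out, the axis attr, and act_type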
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/operators/elementwise_ops.h"
namespace paddle {
namespace lite {
namespace operators {
class FusionElementwiseActivationOp : public ElementwiseOp {
public:
explicit FusionElementwiseActivationOp(const std::string& type)
: ElementwiseOp(type) {}
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
std::string DebugString() const override {
return "fusion_elementwise_activation_op";
}
private:
mutable operators::FusionElementwiseActivationParam param_;
};
#ifdef LITE_WITH_X86
class FusionElementwiseActivationGradExplicitOp
: public ElementwiseGradExplicitOp {
public:
explicit FusionElementwiseActivationGradExplicitOp(const std::string& type)
: ElementwiseGradExplicitOp(type) {}
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
std::string DebugString() const override {
return "fusion_elementwise_activation_grad_explicit_op";
}
private:
mutable operators::FusionElementwiseActivationGradParam param_;
};
#endif
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(fusion_elementwise_activation_op_lite, test) {
// prepare variables
lite::Scope scope;
auto* x = scope.Var("x")->GetMutable<lite::Tensor>();
auto* y = scope.Var("y")->GetMutable<lite::Tensor>();
auto* out = scope.Var("out")->GetMutable<lite::Tensor>();
x->Resize(lite::DDim(std::vector<int64_t>({10, 20})));
y->Resize(lite::DDim(std::vector<int64_t>({10, 20})));
out->Resize(lite::DDim(std::vector<int64_t>{10, 20}));
// set data
for (int i = 0; i < 10 * 20; i++) {
x->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 10 * 20; i++) {
y->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 10 * 20; i++) {
out->mutable_data<float>()[i] = 0.;
}
// prepare op desc
cpp::OpDesc desc;
desc.SetType("fusion_elementwise_add_activation");
desc.SetInput("X", {"x"});
desc.SetInput("Y", {"y"});
desc.SetOutput("Out", {"out"});
desc.SetAttr("axis", static_cast<int>(1));
desc.SetAttr("act_type", std::string("relu"));
FusionElementwiseActivationOp fuse_op("fusion_elementwise_add_activation");
fuse_op.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}});
fuse_op.Attach(desc, &scope);
}
} // namespace operators
} // namespace lite
} // namespace paddle
@@ -219,6 +219,14 @@ struct ElementwiseGradParam {
   int axis{-1};  // for broadcasting.
 };

+struct FusionElementwiseActivationParam : public ElementwiseParam {
+  std::string act_type;
+};
+
+struct FusionElementwiseActivationGradParam : public ElementwiseGradParam {
+  std::string act_type;
+};
+
 /// ----------------------- activation operators ----------------------
 struct ActivationParam {
   const lite::Tensor* X{};
...
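
Because the new param structs inherit from the elementwise params, a kernel reading FusionElementwiseActivationParam sees the base fields (inputs, output, broadcast axis) plus the activation type. A minimal usage sketch:

paddle::lite::operators::FusionElementwiseActivationParam param;
param.axis = 1;           // inherited from ElementwiseParam
param.act_type = "relu";  // currently the only supported activation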