未验证 提交 1a64347a 编写于 作者: H HappyAngel 提交者: GitHub

[arm] add scale+relu/relu6/leakyrelu fusion (#3461)

* add scale+relu/relu6/leakyrelu test=develop
* fix format, test=develop
上级 4d495329
......@@ -52,6 +52,7 @@ USE_MIR_PASS(mlu_postprocess_pass);
USE_MIR_PASS(weight_quantization_preprocess_pass);
USE_MIR_PASS(apu_subgraph_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass);
USE_MIR_PASS(lite_scale_activation_fuse_pass);
USE_MIR_PASS(__xpu__resnet_fuse_pass);
USE_MIR_PASS(__xpu__multi_encoder_fuse_pass);
USE_MIR_PASS(__xpu__embedding_with_eltwise_add_fuse_pass);
......
此差异已折叠。
......@@ -40,6 +40,15 @@ void scale_compute_basic(const operators::ScaleParam& param) {
template <typename T>
void scale(const T* din, T* dout, int num, T scale, T bias);
template <typename T>
void scale_relu(const T* din, T* dout, int num, T scale, T bias);
template <typename T>
void scale_relu6(const T* din, T* dout, int num, T scale, T bias, T alpha);
template <typename T>
void scale_leaky_relu(const T* din, T* dout, int num, T scale, T bias, T alpha);
template <typename T>
void scale(const T* din,
T* dout,
......
......@@ -21,6 +21,7 @@ lite_cc_library(mir_passes
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
fusion/sequence_pool_concat_fuse_pass.cc
fusion/scale_activation_fuse_pass.cc
fusion/__xpu__resnet_fuse_pass.cc
fusion/__xpu__multi_encoder_fuse_pass.cc
fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc
......
......@@ -31,6 +31,9 @@ lite_cc_library(fuse_interpolate
lite_cc_library(fuse_sequence_pool_concat
SRCS sequence_pool_concat_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_scale_activation
SRCS scale_activation_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
......@@ -44,6 +47,7 @@ set(mir_fusers
fuse_transpose_softmax_transpose
fuse_interpolate
fuse_sequence_pool_concat
fuse_scale_activation
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scale_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/scale_activation_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ScaleActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
for (auto act_type : {"relu", "relu6", "leaky_relu"}) {
fusion::ScaleActivationFuser fuser(act_type);
fuser(graph.get());
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_scale_activation_fuse_pass,
paddle::lite::mir::ScaleActivationFusePass)
.BindTargets({TARGET(kARM)})
.BindKernel("scale");
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ScaleActivationFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/scale_activation_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ScaleActivationFuser::BuildPattern() {
// create input nodes.
auto* x = VarNode("x")->assert_is_op_input("scale", "X")->AsInput();
// create op nodes
auto* scale =
OpNode("scale", "scale")->assert_is_op("scale")->AsIntermediate();
auto* act =
OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
// create intermediate nodes
auto* scale_out = VarNode("scale_out")
->assert_is_op_output("scale", "Out")
->assert_is_op_input(act_type_, "X")
->AsIntermediate();
// create output node
auto* out =
VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
// create topology.
*x >> *scale >> *scale_out;
*scale_out >> *act >> *out;
}
void ScaleActivationFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto scale_op = LiteOpRegistry::Global().Create("scale");
auto scale = matched.at("scale")->stmt()->op();
auto* scope = scale->scope();
auto& valid_places = scale->valid_places();
scale_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(scale_op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc ScaleActivationFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("scale")->stmt()->op_info();
op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
cpp::OpDesc act_op_desc = *matched.at("act")->stmt()->op_info();
op_desc.SetAttr("activation_type", act_type_);
if (act_type_ == "relu") {
op_desc.SetAttr("fuse_relu", true);
} else if (act_type_ == "relu6") {
float alpha = act_op_desc.GetAttr<float>("threshold");
op_desc.SetAttr("alpha", alpha);
} else if (act_type_ == "leaky_relu") {
float alpha = act_op_desc.GetAttr<float>("alpha");
op_desc.SetAttr("alpha", alpha);
}
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ScaleActivationFuser : public FuseBase {
public:
explicit ScaleActivationFuser(const std::string& act_type) {
act_type_ = act_type;
}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string act_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -71,6 +71,7 @@ class Optimizer {
"identity_scale_eliminate_pass", //
"elementwise_mul_constant_eliminate_pass", //
"lite_sequence_pool_concat_fuse_pass", //
"lite_scale_activation_fuse_pass", //
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
(defined LITE_WITH_ARM)
"lite_elementwise_add_activation_fuse_pass", //
......
......@@ -31,7 +31,18 @@ void ScaleCompute<T, PType>::Run() {
if (!param.bias_after_scale) {
bias *= scale;
}
lite::arm::math::scale<T>(x_data, output_data, num, scale, bias);
T alpha = param.alpha;
if (param.activation_type == "") { // no act
lite::arm::math::scale<T>(x_data, output_data, num, scale, bias);
} else if (param.activation_type == "relu") { // do relu
lite::arm::math::scale_relu<T>(x_data, output_data, num, scale, bias);
} else if (param.activation_type == "relu6") { // do relu6
lite::arm::math::scale_relu6<T>(
x_data, output_data, num, scale, bias, alpha);
} else if (param.activation_type == "leaky_relu") { // do leaky_relu
lite::arm::math::scale_leaky_relu<T>(
x_data, output_data, num, scale, bias, alpha);
}
if (!param.x->lod().empty()) {
param.output->set_lod(param.x->lod());
}
......
......@@ -244,6 +244,9 @@ struct ScaleParam : ParamBase {
float scale{1.};
float bias{};
bool bias_after_scale{true};
std::string activation_type{""};
bool fuse_relu{false};
float alpha{6.};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const std::vector<const Tensor*>* input_tensor_ptrs() override {
......
......@@ -38,6 +38,20 @@ bool ScaleOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.scale = op_desc.GetAttr<float>("scale");
param_.bias = op_desc.GetAttr<float>("bias");
param_.bias_after_scale = op_desc.GetAttr<bool>("bias_after_scale");
if (op_desc.HasAttr("activation_type")) {
auto act_type = op_desc.GetAttr<std::string>("activation_type");
param_.activation_type = act_type;
if (act_type == "relu") {
param_.fuse_relu = true;
} else if (act_type == "relu6") {
param_.alpha = op_desc.GetAttr<float>("alpha"); // 6.f
} else if (act_type == "leaky_relu") {
param_.alpha = op_desc.GetAttr<float>("alpha");
} else {
CHECK(false)
<< "The fused conv only supports fuse with relu and leaky relu";
}
}
CHECK(param_.x);
CHECK(param_.output);
return true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册