Commit 810eb089 authored by W Wilber, committed by GitHub

add var_conv_2d_relu pass test=develop (#2631)

add var_conv_2d + relu fuse pass
Parent 3ef94cd4
......@@ -34,6 +34,7 @@ USE_MIR_PASS(lite_interpolate_fuse_pass);
USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass);
USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(type_precision_cast_pass);
......
......@@ -16,6 +16,7 @@ lite_cc_library(mir_passes
fusion/interpolate_fuse_pass.cc
fusion/conv_elementwise_fuse_pass.cc
fusion/conv_activation_fuse_pass.cc
fusion/var_conv_2d_activation_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
......
......@@ -10,6 +10,9 @@ lite_cc_library(fuse_conv_elementwise
lite_cc_library(fuse_conv_activation
SRCS conv_activation_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_var_conv_activation
SRCS var_conv_2d_activation_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_conv_bn
SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api)
......@@ -31,6 +34,7 @@ set(mir_fusers
fuse_shuffle_channel
fuse_conv_elementwise
fuse_conv_activation
fuse_var_conv_activation
fuse_conv_bn
fuse_quant_dequant
fuse_elementwise_add_activation
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/var_conv_2d_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
// Runs the var_conv_2d + activation fusion over the whole program graph.
// For every supported activation type (currently only "relu"), a fuser is
// built and applied; matched subgraphs are rewritten into a single
// var_conv_2d op carrying the activation as an attribute.
void VarConv2dActivationFusePass::Apply(
    const std::unique_ptr<SSAGraph>& graph) {
  const std::vector<std::string> act_types{"relu"};
  // Iterate by const reference: copying each std::string per iteration is
  // a needless allocation (clang-tidy: performance-for-range-copy).
  for (const auto& act_type : act_types) {
    fusion::VarConvActivationFuser fuser(act_type, "var_conv_2d");
    fuser(graph.get());
  }
}
} // namespace mir
} // namespace lite
} // namespace paddle
// Register the pass under the name used by the optimizer pass list and
// bind it to CUDA only — the CUDA VarConv2DCompute kernel is the one that
// honors the "fuse_relu" attribute this pass sets.
REGISTER_MIR_PASS(lite_var_conv_2d_activation_fuse_pass,
                  paddle::lite::mir::VarConv2dActivationFusePass)
    .BindTargets({TARGET(kCUDA)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
// Program-level MIR pass that fuses a var_conv_2d op followed by an
// activation op (currently only relu) into a single var_conv_2d op,
// recording the activation via the "fuse_relu" attribute.
class VarConv2dActivationFusePass : public ProgramPass {
 public:
  // Applies the fusion to every matching subgraph in `graph`.
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/var_conv_2d_activation_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
// Declares the subgraph pattern to match: a var_conv_2d op whose "Out"
// output feeds an activation op, plus var_conv_2d's auxiliary "Col"
// output. Everything except the inputs (X, W) and the activation's
// output is marked intermediate so it is erased after fusion.
void VarConvActivationFuser::BuildPattern() {
  // create nodes.
  auto* input = VarNode("X")->assert_is_op_input(conv_type_, "X")->AsInput();
  auto* filter = VarNode("W")->assert_is_op_input(conv_type_, "W")->AsInput();
  auto* conv2d = OpNode("var_conv_2d", conv_type_)->AsIntermediate();
  auto* act = OpNode("act", act_type_)->AsIntermediate();
  // Primary conv output: must be consumed by the activation's "X" input.
  auto* conv2d_out = VarNode("conv2d_out")
                         ->assert_is_op_output(conv_type_, "Out")
                         ->assert_is_op_input(act_type_, "X")
                         ->AsIntermediate();
  // Secondary "Col" output of var_conv_2d; matched (and marked
  // intermediate) so it is removed together with the original op.
  auto* conv2d_out_1 = VarNode("conv2d_out_1")
                           ->assert_is_op_output(conv_type_, "Col")
                           ->AsIntermediate();
  auto* out =
      VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
  // create topology.
  std::vector<PMNode*> conv2d_inputs{filter, input};
  conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out;
  *conv2d >> *conv2d_out_1;
}
// Replaces a matched var_conv_2d + activation subgraph with a single
// fused var_conv_2d op node. The new op reuses the old op's scope and
// valid places, and is wired directly from the original inputs (X, W)
// to the activation's output variable.
void VarConvActivationFuser::InsertNewNode(SSAGraph* graph,
                                           const key2nodes_t& matched) {
  auto op_desc = GenOpDesc(matched);
  auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
  auto conv_old = matched.at("var_conv_2d")->stmt()->op();
  auto* scope = conv_old->scope();
  auto& valid_places = conv_old->valid_places();
  conv_op->Attach(op_desc, scope);
  auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
  // Re-link: inputs feed the fused op, which now produces the
  // activation's former output directly.
  IR_NODE_LINK_TO(matched.at("X"), new_op_node);
  IR_NODE_LINK_TO(matched.at("W"), new_op_node);
  IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
// Builds the op desc for the fused op: starts from the original
// var_conv_2d desc, redirects "Out" to the activation's output variable,
// and records the fused activation via the "fuse_relu" attribute (read by
// VarConv2dOp::AttachImpl / the CUDA kernel).
cpp::OpDesc VarConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
  cpp::OpDesc op_desc = *matched.at("var_conv_2d")->stmt()->op_info();
  op_desc.SetOutput("Out", {matched.at("output")->arg()->name});
  // NOTE: the original code also deep-copied the activation's OpDesc into
  // an unused local (`act_op_desc`); that dead copy has been removed.
  if (act_type_ == "relu") {
    op_desc.SetAttr("fuse_relu", true);
  }
  return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
// Pattern-matcher fuser that collapses `conv_type_` (var_conv_2d)
// followed by `act_type_` (e.g. "relu") into one fused op.
class VarConvActivationFuser : public FuseBase {
 public:
  // act_type: activation op type to fuse (e.g. "relu").
  // conv_type: convolution op type to match (e.g. "var_conv_2d").
  explicit VarConvActivationFuser(const std::string& act_type,
                                  const std::string& conv_type)
      : act_type_(act_type), conv_type_(conv_type) {}

  // Declares the conv -> activation subgraph pattern.
  void BuildPattern() override;
  // Rewrites a matched subgraph into a single fused op node.
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  // Produces the fused op's OpDesc from the matched nodes.
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
  std::string act_type_;
  std::string conv_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -62,6 +62,7 @@ class Optimizer {
// TODO(Superjomn) Refine the fusion related design to select fusion
// kernels for devices automatically.
"lite_conv_activation_fuse_pass", //
"lite_var_conv_2d_activation_fuse_pass", //
"lite_fc_fuse_pass", //
"lite_shuffle_channel_fuse_pass", //
"lite_transpose_softmax_transpose_fuse_pass", //
......
......@@ -73,6 +73,10 @@ void VarConv2DCompute::PrepareForRun() {
(*conv_param_.paddings.get())[i * 2 + 1],
conv_param_.strides[i]));
}
if (param.fuse_relu) {
conv_param_.activation_param.has_active = true;
conv_param_.activation_param.active_type = lite_api::ActivationType::kRelu;
}
conv_param_.output->Resize({output_shape});
conv_impl_.reset(new lite::cuda::math::CudnnConv2D<PRECISION(kFloat)>);
conv_impl_->init(conv_param_, &context);
......
......@@ -864,6 +864,8 @@ struct VarConv2DParam {
int stride_w;
int kernel_h;
int kernel_w;
bool fuse_relu{false};
};
/// ----------------------- shape operators ----------------------
......
......@@ -48,6 +48,10 @@ bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.kernel_w = opdesc.GetAttr<int>("KernelW");
param_.stride_h = opdesc.GetAttr<int>("StrideH");
param_.stride_w = opdesc.GetAttr<int>("StrideW");
if (opdesc.HasAttr("fuse_relu")) {
param_.fuse_relu = opdesc.GetAttr<bool>("fuse_relu");
}
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register