Unverified commit a4586f48, authored by W Wilber, committed by GitHub

[CUDA] [Fuse] Fuse relu. (#4090)

Parent d1258513
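In brief: this change adds a CUDA fuse pass that folds a trailing relu into match_matrix_tensor (the CUDA kernel applies relu while padding its output), carries the COLUMN/ROW inputs through the var_conv_2d activation fuser, and enables the FC fuser for CUDA builds. A minimal sketch of the graph rewrite, in illustrative pseudo-IR (not literal syntax):

// Before fusion:
//   {Out, Tmp} = match_matrix_tensor(X, W, Y)  {dim_t}
//   Out2       = relu(Out)
// After fusion:
//   {Out2, Tmp} = match_matrix_tensor(X, W, Y)  {dim_t, fuse_relu=true}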
......@@ -39,6 +39,7 @@ USE_MIR_PASS(identity_dropout_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass);
USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
USE_MIR_PASS(lite_match_matrix_activation_fuse_pass);
USE_MIR_PASS(lite_elementwise_activation_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(type_precision_cast_pass);
......
......@@ -29,6 +29,7 @@ lite_cc_library(mir_passes
fusion/__xpu__embedding_with_eltwise_add_fuse_pass.cc
fusion/__xpu__fc_fuse_pass.cc
fusion/__xpu__mmdnn_fuse_pass.cc
fusion/match_matrix_activation_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
......
......@@ -37,6 +37,9 @@ lite_cc_library(fuse_sequence_pool_concat
lite_cc_library(fuse_scale_activation
SRCS scale_activation_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_match_matrix_activation
SRCS match_matrix_activation_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
......@@ -52,6 +55,7 @@ set(mir_fusers
fuse_interpolate
fuse_sequence_pool_concat
fuse_scale_activation
fuse_match_matrix_activation
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
......@@ -23,7 +23,7 @@ namespace lite {
namespace mir {
void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-#ifdef LITE_WITH_X86
+#if defined(LITE_WITH_X86) || defined(LITE_WITH_CUDA)
#ifdef LITE_WITH_MLU
fusion::FcFuser fuser(false);
fuser(graph.get());
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/match_matrix_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/match_matrix_activation_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void MatchMatrixActFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  fusion::MatchMatrixActFuser fuser("relu");
  fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_match_matrix_activation_fuse_pass,
                  paddle::lite::mir::MatchMatrixActFusePass)
    .BindTargets({TARGET(kCUDA)});
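If more activations were ever to be fused, Apply above would be the natural extension point. A hypothetical sketch (this commit supports relu only, and GenOpDesc below hard-codes the fuse_relu attribute):

// Hypothetical generalization, not part of this diff:
// void MatchMatrixActFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
//   for (auto act_type : {"relu" /*, other supported activations */}) {
//     fusion::MatchMatrixActFuser fuser(act_type);
//     fuser(graph.get());
//   }
// }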
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class MatchMatrixActFusePass : public ProgramPass {
 public:
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/match_matrix_activation_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void MatchMatrixActFuser::BuildPattern() {
  // create nodes.
  auto* x = VarNode("x")->assert_is_op_input("match_matrix_tensor", "X");
  auto* W = VarNode("W")->assert_is_op_input("match_matrix_tensor", "W");
  auto* y = VarNode("y")->assert_is_op_input("match_matrix_tensor", "Y");
  auto* mm = OpNode("match_matrix_tensor", "match_matrix_tensor");
  auto* mm_out =
      VarNode("mm_out")->assert_is_op_output("match_matrix_tensor", "Out");
  auto* mm_tmp =
      VarNode("mm_tmp")->assert_is_op_output("match_matrix_tensor", "Tmp");
  auto* act = OpNode("act", activation_);
  auto* out = VarNode("Out")->assert_is_op_output(activation_, "Out");

  // create topology.
  std::vector<PMNode*> mm_inputs{x, W, y};
  std::vector<PMNode*> mm_outputs{mm_out, mm_tmp};
  mm_inputs >> *mm >> mm_outputs;

  // Some op specialities.
  mm_out->AsIntermediate();
  mm->AsIntermediate();
  act->AsIntermediate();

  *mm_out >> *act >> *out;
}
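// Matched subgraph, for reference (sketch of the pattern built above):
//
//   x ──┐
//   W ──┼─> match_matrix_tensor ─┬─> mm_out ─> act(relu) ─> Out
//   y ──┘                        └─> mm_tmp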
void MatchMatrixActFuser::InsertNewNode(SSAGraph* graph,
                                        const key2nodes_t& matched) {
  auto op_desc = GenOpDesc(matched);
  auto mm_op = LiteOpRegistry::Global().Create("match_matrix_tensor");
  auto mm = matched.at("match_matrix_tensor")->stmt()->op();
  auto* scope = mm->scope();
  auto& valid_places = mm->valid_places();
  mm_op->Attach(op_desc, scope);

  auto* new_op_node = graph->GraphCreateInstructNode(mm_op, valid_places);

  IR_NODE_LINK_TO(matched.at("x"), new_op_node);
  IR_NODE_LINK_TO(matched.at("W"), new_op_node);
  IR_NODE_LINK_TO(matched.at("y"), new_op_node);
  IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
}
cpp::OpDesc MatchMatrixActFuser::GenOpDesc(const key2nodes_t& matched) {
  auto op_desc = *matched.at("match_matrix_tensor")->stmt()->op_info();
  int dim_t = matched.at("match_matrix_tensor")
                  ->stmt()
                  ->op_info()
                  ->GetAttr<int>("dim_t");
  op_desc.mutable_inputs()->clear();
  op_desc.mutable_outputs()->clear();
  op_desc.SetType("match_matrix_tensor");
  op_desc.SetInput("X", {matched.at("x")->arg()->name});
  op_desc.SetInput("W", {matched.at("W")->arg()->name});
  op_desc.SetInput("Y", {matched.at("y")->arg()->name});
  op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
  op_desc.SetOutput("Tmp", {matched.at("mm_tmp")->arg()->name});
  op_desc.SetAttr("dim_t", dim_t);
op_desc.SetAttr("fuse_relu", true);
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class MatchMatrixActFuser : public FuseBase {
 public:
  explicit MatchMatrixActFuser(std::string activation)
      : activation_(activation) {}

  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
  std::string activation_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -25,6 +25,9 @@ void VarConvActivationFuser::BuildPattern() {
// create nodes.
auto* input = VarNode("X")->assert_is_op_input(conv_type_, "X")->AsInput();
auto* filter = VarNode("W")->assert_is_op_input(conv_type_, "W")->AsInput();
auto* column =
VarNode("COLUMN")->assert_is_op_input(conv_type_, "COLUMN")->AsInput();
auto* row = VarNode("ROW")->assert_is_op_input(conv_type_, "ROW")->AsInput();
auto* conv2d = OpNode("var_conv_2d", conv_type_)->AsIntermediate();
......@@ -42,7 +45,7 @@ void VarConvActivationFuser::BuildPattern() {
VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
// create topology.
-  std::vector<PMNode*> conv2d_inputs{filter, input};
+  std::vector<PMNode*> conv2d_inputs{filter, input, column, row};
conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out;
*conv2d >> *conv2d_out_1;
}
......@@ -60,6 +63,8 @@ void VarConvActivationFuser::InsertNewNode(SSAGraph* graph,
IR_NODE_LINK_TO(matched.at("X"), new_op_node);
IR_NODE_LINK_TO(matched.at("W"), new_op_node);
IR_NODE_LINK_TO(matched.at("COLUMN"), new_op_node);
IR_NODE_LINK_TO(matched.at("ROW"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
......
......@@ -91,6 +91,7 @@ class Optimizer {
// kernels for devices automatically.
"lite_conv_activation_fuse_pass", //
"lite_var_conv_2d_activation_fuse_pass", //
"lite_match_matrix_activation_fuse_pass", //
"lite_fc_fuse_pass", //
"lite_shuffle_channel_fuse_pass", //
"lite_transpose_softmax_transpose_fuse_pass", //
......
......@@ -52,6 +52,7 @@ __global__ void padding_out(const dtype* src,
const int max_len_r,
const int tl,
const int count,
const bool fuse_relu,
dtype* dst) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int thread_num = blockDim.x * gridDim.x;
......@@ -62,7 +63,13 @@ __global__ void padding_out(const dtype* src,
int r_id = tid % max_len_r;
int cur_len = offset[seq_id + 1] - offset[seq_id];
if (r_id < cur_len) {
-      dst[tid] = src[(offset[seq_id] + r_id) * tl + tl_id];
+      if (fuse_relu) {
+        dst[tid] = src[(offset[seq_id] + r_id) * tl + tl_id] > 0
+                       ? src[(offset[seq_id] + r_id) * tl + tl_id]
+                       : 0;
+      } else {
+        dst[tid] = src[(offset[seq_id] + r_id) * tl + tl_id];
+      }
} else {
dst[tid] = 0.f;
}
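// Equivalent scalar form of the fused branch above (illustrative only):
//   dtype v = src[(offset[seq_id] + r_id) * tl + tl_id];
//   dst[tid] = fuse_relu ? (v > 0 ? v : 0) : v;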
......@@ -86,6 +93,7 @@ void MatchMatrixTensorCompute::Run() {
auto* tmp = param.tmp;
int dim_t = param.dim_t;
int dim_in = x->dims()[1];
bool fuse_relu = param.fuse_relu;
const auto& offset_l = x->lod()[0];
const auto& offset_r = y->lod()[0];
......@@ -155,6 +163,7 @@ void MatchMatrixTensorCompute::Run() {
max_len_r,
dim_t * len_l,
count,
fuse_relu,
out_data);
out->set_lod(y->lod());
}
......
......@@ -43,6 +43,8 @@ bool SequencePoolOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.MaxIndex = scope->FindVar(opdesc.Output("MaxIndex").front())
->GetMutable<lite::Tensor>();
param_.pool_type = opdesc.GetAttr<std::string>("pooltype");
CHECK(param_.X);
CHECK(param_.Out);
......