Commit fb345c46 authored by: C chonwhite

added fc_relu and fc_sigmoid fusion

Parent: e5f6ce88
......@@ -48,9 +48,6 @@ class FullyConnectedPE : public PE {
int num = param_.filter->shape().channel();
int chw = param_.filter->shape().num();
// if (num == 2) {
// return;
// }
int height = param_.input->shape().height();
int width = param_.input->shape().width();
......
......@@ -23,13 +23,23 @@ namespace lite {
namespace mir {
void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
#ifdef LITE_WITH_X86 || LITE_WITH_FPGA
fusion::FcFuser fuser(true);
fuser(graph.get());
std::vector<std::string> act_types{};
#ifdef LITE_WITH_X86
act_types.push_back("relu");
#endif
fusion::FcFuser fuser2(false);
fuser2(graph.get());
#ifdef LITE_WITH_FPGA
act_types.push_back("relu");
act_types.push_back("sigmoid");
#endif
act_types.push_back("");
for (int i = 0; i < act_types.size(); i++) {
std::string act_type = act_types[i];
fusion::FcFuser fuser(act_type);
fuser(graph.get());
}
}
} // namespace mir
......
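Conceptually, the rewritten pass replaces the fixed relu/no-relu pair of fusers with a loop over a backend-specific list of activation names, plus the empty string for the activation-free pattern. Below is a minimal, self-contained sketch of that dispatch pattern; `FakeFcFuser` is a stand-in for `fusion::FcFuser` and only prints what would be fused, it is not part of the patch.

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Stand-in for fusion::FcFuser: one fuser instance per activation type.
struct FakeFcFuser {
  explicit FakeFcFuser(std::string act_type) : act_type_(std::move(act_type)) {}
  void Run() const {
    std::cout << "fuse mul + elementwise_add"
              << (act_type_.empty() ? std::string("") : " + " + act_type_)
              << " -> fc" << std::endl;
  }
  std::string act_type_;
};

int main() {
  // In the real pass "relu"/"sigmoid" are pushed only under LITE_WITH_X86 /
  // LITE_WITH_FPGA; listed unconditionally here for illustration.
  std::vector<std::string> act_types = {"relu", "sigmoid", ""};
  for (const std::string& act_type : act_types) {
    FakeFcFuser fuser(act_type);  // one fusion pass per activation type
    fuser.Run();
  }
  return 0;
}
```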
......@@ -41,15 +41,17 @@ void FcFuser::BuildPattern() {
mul->AsIntermediate();
add->AsIntermediate();
if (with_relu_) {
if (act_type_ != "") {
std::cout << "act_type_:" << act_type_ << std::endl;
auto* add_out = VarNode("add_out");
auto* relu = OpNode("relu", "relu");
std::vector<PMNode*> relu_inputs{add_out};
auto* activation = OpNode(act_type_, act_type_);
std::vector<PMNode*> act_inputs{add_out};
add_inputs >> *add >> *add_out;
relu_inputs >> *relu >> *Out;
act_inputs >> *activation >> *Out;
add_out->AsIntermediate();
relu->AsIntermediate();
activation->AsIntermediate();
} else {
std::cout << "act_type_: empty" << std::endl;
add_inputs >> *add >> *Out;
}
}
......@@ -82,8 +84,8 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc.SetAttr(
"in_num_col_dims",
matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
if (with_relu_) {
op_desc.SetAttr("activation_type", std::string{"relu"});
if (act_type_ != "") {
op_desc.SetAttr("activation_type", act_type_);
}
return op_desc;
}
......
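Note: BuildPattern and GenOpDesc are generalized in the same way. The hard-coded relu node in the pattern becomes an op node of type act_type_, so the same matching code now covers fc+relu and fc+sigmoid; when act_type_ is empty, the elementwise_add output feeds Out directly and no activation_type attribute is set on the fused fc op.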
......@@ -25,13 +25,13 @@ namespace fusion {
class FcFuser : public FuseBase {
public:
explicit FcFuser(bool with_relu) : with_relu_(with_relu) {}
explicit FcFuser(std::string act_type) : act_type_(act_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
bool with_relu_;
std::string act_type_ = "";
};
} // namespace fusion
......
......@@ -14,6 +14,8 @@
#pragma once
#include <algorithm>
#include <map>
#include <string>
#include "lite/backends/fpga/KD/float16.hpp"
#include "lite/backends/fpga/KD/pes/relu_pe.hpp"
#include "lite/core/kernel.h"
......@@ -24,6 +26,13 @@ namespace lite {
namespace kernels {
namespace fpga {
static std::map<std::string, zynqmp::ActiveType> activation_map = {
{"relu", zynqmp::TYPE_RELU},
{"relu6", zynqmp::TYPE_RELU6},
{"leaky_relu", zynqmp::TYPE_LEAKY_RELU},
{"sigmoid", zynqmp::TYPE_SIGMOID},
{"", zynqmp::TYPE_NONE}};
class ReluCompute
: public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
public:
......
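The new activation_map gives the FPGA kernels a single place to translate the fused op's activation_type string into a zynqmp::ActiveType. A minimal, self-contained sketch of that lookup pattern follows; the enum and table names are stand-ins for the real zynqmp types, not the actual API.

```cpp
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

// Stand-in for zynqmp::ActiveType; the real table is activation_map above.
enum class ActiveType { NONE, RELU, RELU6, LEAKY_RELU, SIGMOID };

static std::map<std::string, ActiveType> activation_table = {
    {"relu", ActiveType::RELU},
    {"relu6", ActiveType::RELU6},
    {"leaky_relu", ActiveType::LEAKY_RELU},
    {"sigmoid", ActiveType::SIGMOID},
    {"", ActiveType::NONE}};

// Guarded lookup, mirroring the count()/LOG(FATAL) pattern used in
// ElementwiseAddActivationCompute::PrepareForRun below.
ActiveType LookupActivation(const std::string& act_type) {
  auto it = activation_table.find(act_type);
  if (it == activation_table.end()) {
    std::cerr << "unsupported Activation type: " << act_type << std::endl;
    std::abort();
  }
  return it->second;
}

int main() {
  std::cout << static_cast<int>(LookupActivation("sigmoid")) << std::endl;  // 4
  std::cout << static_cast<int>(LookupActivation("")) << std::endl;         // 0
  return 0;
}
```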
......@@ -16,6 +16,7 @@
#include <string>
#include "lite/backends/arm/math/funcs.h"
#include "lite/backends/fpga/KD/debugger.hpp"
#include "lite/kernels/fpga/activation_compute.h"
namespace paddle {
namespace lite {
......@@ -29,11 +30,9 @@ void ElementwiseAddCompute::PrepareForRun() {
auto& param = Param<operators::ElementwiseParam>();
param.Out->mutable_data<float16>();
ew_param.inputs = {param.X->ZynqTensor(), param.Y->ZynqTensor()};
ew_param.output = param.Out->ZynqTensor();
ew_param.axis = param.axis;
ew_param.activeParam.type = zynqmp::TYPE_NONE;
pe_.init();
......@@ -50,14 +49,17 @@ void ElementwiseAddCompute::Run() {
void ElementwiseAddActivationCompute::PrepareForRun() {
zynqmp::ElementwiseAddParam& ew_param = pe_.param();
auto& param = Param<operators::FusionElementwiseActivationParam>();
if (param.act_type != "relu") {
if (activation_map.count(param.act_type)) {
ew_param.activeParam.type = activation_map[param.act_type];
} else {
LOG(FATAL) << "unsupported Activation type: " << param.act_type;
}
param.Out->mutable_data<float16>();
ew_param.inputs = {param.X->ZynqTensor(), param.Y->ZynqTensor()};
ew_param.output = param.Out->ZynqTensor();
ew_param.axis = param.axis;
ew_param.activeParam.type = zynqmp::TYPE_RELU;
pe_.init();
pe_.apply();
}
......@@ -76,7 +78,6 @@ void ElementwiseMulCompute::PrepareForRun() {
scale_param.input = param.X->ZynqTensor();
scale_param.output = param.Out->ZynqTensor();
scale_param.activeParam.type = zynqmp::TYPE_NONE;
int channel = scale_param.input->shape().channel();
......@@ -103,9 +104,10 @@ void ElementwiseMulCompute::PrepareForRun() {
void ElementwiseMulCompute::Run() {
auto& param = Param<operators::ElementwiseParam>();
param.Y->ZynqTensor()->flush();
if (!param.Y->persistable()) {
scale_.copyFrom(param.Y->ZynqTensor());
scale_.invalidate();
}
pe_.dispatch();
#ifdef FPGA_PRINT_TENSOR
zynqmp::ScaleParam& scale_param = pe_.param();
......@@ -170,7 +172,10 @@ REGISTER_LITE_KERNEL(elementwise_mul,
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kFPGA),
PRECISION(kFP16),
......
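Note: ElementwiseAddActivationCompute now resolves the fused activation through activation_map instead of accepting only "relu", and raises LOG(FATAL) for unsupported types. The elementwise_mul kernel's Y input is re-bound from an ARM tensor to an FPGA FP16/NHWC tensor, and in Run the Y operand is copied into the scale tensor only when it is not persistable.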
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "lite/kernels/fpga/fc_compute.h"
#include "lite/kernels/fpga/activation_compute.h"
#include "lite/backends/fpga/KD/debugger.hpp"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
......@@ -36,6 +38,10 @@ void FcCompute::PrepareForRun() {
fc_param.filter = param.w->ZynqTensor();
fc_param.bias = param.bias->ZynqTensor();
if (activation_map.count(param.activation_type)) {
fc_param.activeParam.type = activation_map[param.activation_type];
}
pe_.init();
pe_.apply();
}
......
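Note: this is the kernel-side half of the fusion added by this commit. FcCompute::PrepareForRun looks up the fused op's activation_type in activation_map and forwards it to the FPGA PE via fc_param.activeParam.type, so a fused fc+relu or fc+sigmoid runs as one FPGA kernel.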