Commit fb345c46 authored by chonwhite

added fc_relu and fc_sigmoid fusion

Parent e5f6ce88
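In short: FcFuser previously took a bool (fuse a trailing relu or not) and now takes the activation op type as a string, so the same mul + elementwise_add pattern can absorb different activations into the fused fc op. A rough illustration of the fusers the pass can now build (variable names are illustrative only; the exact set depends on build flags, as the fc_fuse_pass.cc hunk below shows):

  fusion::FcFuser relu_fuser("relu");        // mul + elementwise_add + relu    -> fc with activation_type "relu"
  fusion::FcFuser sigmoid_fuser("sigmoid");  // mul + elementwise_add + sigmoid -> fc with activation_type "sigmoid"
  fusion::FcFuser plain_fuser("");           // mul + elementwise_add           -> fc with no activation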
@@ -48,9 +48,6 @@ class FullyConnectedPE : public PE {
   int num = param_.filter->shape().channel();
   int chw = param_.filter->shape().num();
-  // if (num == 2) {
-  //   return;
-  // }
   int height = param_.input->shape().height();
   int width = param_.input->shape().width();
...
@@ -23,13 +23,23 @@ namespace lite {
 namespace mir {
 
 void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-#ifdef LITE_WITH_X86 || LITE_WITH_FPGA
-  fusion::FcFuser fuser(true);
-  fuser(graph.get());
+  std::vector<std::string> act_types{};
+#ifdef LITE_WITH_X86
+  act_types.push_back("relu");
 #endif
-  fusion::FcFuser fuser2(false);
-  fuser2(graph.get());
+#ifdef LITE_WITH_FPGA
+  act_types.push_back("relu");
+  act_types.push_back("sigmoid");
+#endif
+  act_types.push_back("");
+  for (int i = 0; i < act_types.size(); i++) {
+    std::string act_type = act_types[i];
+    fusion::FcFuser fuser(act_type);
+    fuser(graph.get());
+  }
 }
 
 }  // namespace mir
...
@@ -41,15 +41,17 @@ void FcFuser::BuildPattern() {
   mul->AsIntermediate();
   add->AsIntermediate();
 
-  if (with_relu_) {
+  if (act_type_ != "") {
+    std::cout << "act_type_:" << act_type_ << std::endl;
     auto* add_out = VarNode("add_out");
-    auto* relu = OpNode("relu", "relu");
-    std::vector<PMNode*> relu_inputs{add_out};
+    auto* activation = OpNode(act_type_, act_type_);
+    std::vector<PMNode*> act_inputs{add_out};
     add_inputs >> *add >> *add_out;
-    relu_inputs >> *relu >> *Out;
+    act_inputs >> *activation >> *Out;
     add_out->AsIntermediate();
-    relu->AsIntermediate();
+    activation->AsIntermediate();
   } else {
+    std::cout << "act_type_: empty" << std::endl;
     add_inputs >> *add >> *Out;
   }
 }
@@ -82,8 +84,8 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
   op_desc.SetAttr(
       "in_num_col_dims",
       matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
-  if (with_relu_) {
-    op_desc.SetAttr("activation_type", std::string{"relu"});
+  if (act_type_ != "") {
+    op_desc.SetAttr("activation_type", act_type_);
   }
   return op_desc;
 }
...
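With GenOpDesc above, the fused fc op carries the activation as a string attribute instead of a hard-coded "relu". A minimal sketch of how a consumer could read it back, assuming cpp::OpDesc's HasAttr/GetAttr interface (the local variable is illustrative, not part of the patch):

  std::string act_type;
  if (op_desc.HasAttr("activation_type")) {
    act_type = op_desc.GetAttr<std::string>("activation_type");
  }
  // the attribute is absent for a plain mul + elementwise_add match,
  // so act_type stays empty in that case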
@@ -25,13 +25,13 @@ namespace fusion {
 class FcFuser : public FuseBase {
  public:
-  explicit FcFuser(bool with_relu) : with_relu_(with_relu) {}
+  explicit FcFuser(std::string act_type) : act_type_(act_type) {}
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
-  bool with_relu_;
+  std::string act_type_ = "";
 };
 
 }  // namespace fusion
...
@@ -14,6 +14,8 @@
 #pragma once
 
 #include <algorithm>
+#include <map>
+#include <string>
 #include "lite/backends/fpga/KD/float16.hpp"
 #include "lite/backends/fpga/KD/pes/relu_pe.hpp"
 #include "lite/core/kernel.h"
@@ -24,6 +26,13 @@ namespace lite {
 namespace kernels {
 namespace fpga {
 
+static std::map<std::string, zynqmp::ActiveType> activation_map = {
+    {"relu", zynqmp::TYPE_RELU},
+    {"relu6", zynqmp::TYPE_RELU6},
+    {"leaky_relu", zynqmp::TYPE_LEAKY_RELU},
+    {"sigmoid", zynqmp::TYPE_SIGMOID},
+    {"", zynqmp::TYPE_NONE}};
+
 class ReluCompute
     : public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
  public:
...
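activation_map gives the FPGA kernels one place to translate the activation_type string into a zynqmp::ActiveType. A minimal lookup sketch using only the entries declared above (the ToActiveType helper is illustrative, not part of the patch):

  static zynqmp::ActiveType ToActiveType(const std::string& act_type) {
    auto it = activation_map.find(act_type);
    // unknown names fall back to TYPE_NONE instead of inserting a new entry
    return (it == activation_map.end()) ? zynqmp::TYPE_NONE : it->second;
  }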
@@ -16,6 +16,7 @@
 #include <string>
 #include "lite/backends/arm/math/funcs.h"
 #include "lite/backends/fpga/KD/debugger.hpp"
+#include "lite/kernels/fpga/activation_compute.h"
 
 namespace paddle {
 namespace lite {
@@ -29,11 +30,9 @@ void ElementwiseAddCompute::PrepareForRun() {
   auto& param = Param<operators::ElementwiseParam>();
   param.Out->mutable_data<float16>();
   ew_param.inputs = {param.X->ZynqTensor(), param.Y->ZynqTensor()};
   ew_param.output = param.Out->ZynqTensor();
   ew_param.axis = param.axis;
   ew_param.activeParam.type = zynqmp::TYPE_NONE;
 
   pe_.init();
@@ -50,14 +49,17 @@ void ElementwiseAddCompute::Run() {
 void ElementwiseAddActivationCompute::PrepareForRun() {
   zynqmp::ElementwiseAddParam& ew_param = pe_.param();
   auto& param = Param<operators::FusionElementwiseActivationParam>();
-  if (param.act_type != "relu") {
+  if (activation_map.count(param.act_type)) {
+    ew_param.activeParam.type = activation_map[param.act_type];
+  } else {
     LOG(FATAL) << "unsupported Activation type: " << param.act_type;
   }
   param.Out->mutable_data<float16>();
   ew_param.inputs = {param.X->ZynqTensor(), param.Y->ZynqTensor()};
   ew_param.output = param.Out->ZynqTensor();
   ew_param.axis = param.axis;
-  ew_param.activeParam.type = zynqmp::TYPE_RELU;
 
   pe_.init();
   pe_.apply();
 }
@@ -76,7 +78,6 @@ void ElementwiseMulCompute::PrepareForRun() {
   scale_param.input = param.X->ZynqTensor();
   scale_param.output = param.Out->ZynqTensor();
   scale_param.activeParam.type = zynqmp::TYPE_NONE;
 
   int channel = scale_param.input->shape().channel();
@@ -103,9 +104,10 @@ void ElementwiseMulCompute::Run() {
 void ElementwiseMulCompute::Run() {
   auto& param = Param<operators::ElementwiseParam>();
-  param.Y->ZynqTensor()->flush();
+  if (!param.Y->persistable()) {
   scale_.copyFrom(param.Y->ZynqTensor());
   scale_.invalidate();
+  }
   pe_.dispatch();
 #ifdef FPGA_PRINT_TENSOR
   zynqmp::ScaleParam& scale_param = pe_.param();
@@ -170,7 +172,10 @@ REGISTER_LITE_KERNEL(elementwise_mul,
                      {LiteType::GetTensorTy(TARGET(kFPGA),
                                             PRECISION(kFP16),
                                             DATALAYOUT(kNHWC))})
-    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(TARGET(kFPGA),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kNHWC))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kFPGA),
                                        PRECISION(kFP16),
...
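Two behavioural notes on the elementwise hunks above: ElementwiseAddActivationCompute now accepts any activation listed in activation_map rather than only "relu", and ElementwiseMulCompute::Run() only refreshes the scale tensor when Y is not persistable. A hedged sketch of the intent behind the latter, assuming persistable() marks constant tensors such as weights loaded with the model:

  if (!param.Y->persistable()) {
    // Y comes from an upstream op and may change between runs,
    // so it has to be copied into scale_ on every dispatch.
    scale_.copyFrom(param.Y->ZynqTensor());
    scale_.invalidate();
  }
  // a persistable Y is constant, so the copy made earlier (e.g. in
  // PrepareForRun) is assumed to still be valid and is reused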
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "lite/kernels/fpga/fc_compute.h"
+#include "lite/kernels/fpga/activation_compute.h"
+
 #include "lite/backends/fpga/KD/debugger.hpp"
 #include "lite/core/op_registry.h"
 #include "lite/core/type_system.h"
@@ -36,6 +38,10 @@ void FcCompute::PrepareForRun() {
   fc_param.filter = param.w->ZynqTensor();
   fc_param.bias = param.bias->ZynqTensor();
 
+  if (activation_map.count(param.activation_type)) {
+    fc_param.activeParam.type = activation_map[param.activation_type];
+  }
+
   pe_.init();
   pe_.apply();
 }
...