diff --git a/lite/backends/fpga/KD/pes/fully_connected_pe.hpp b/lite/backends/fpga/KD/pes/fully_connected_pe.hpp
index 01e16a454af8f14c06b7d62fbefe9b29cfef2850..e5d4c128ad089207fa7e7c7651466526438b7450 100644
--- a/lite/backends/fpga/KD/pes/fully_connected_pe.hpp
+++ b/lite/backends/fpga/KD/pes/fully_connected_pe.hpp
@@ -48,9 +48,6 @@ class FullyConnectedPE : public PE {
     int num = param_.filter->shape().channel();
     int chw = param_.filter->shape().num();
-    // if (num == 2) {
-    //   return;
-    // }
 
     int height = param_.input->shape().height();
     int width = param_.input->shape().width();
diff --git a/lite/core/mir/fusion/fc_fuse_pass.cc b/lite/core/mir/fusion/fc_fuse_pass.cc
index a4df3a143a5ef3569e74d4401cf75ab5d8c789c7..447fc264f0cdf785942644c96d029f0e13d92f20 100644
--- a/lite/core/mir/fusion/fc_fuse_pass.cc
+++ b/lite/core/mir/fusion/fc_fuse_pass.cc
@@ -23,13 +23,23 @@ namespace lite {
 namespace mir {
 
 void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-#ifdef LITE_WITH_X86 || LITE_WITH_FPGA
-  fusion::FcFuser fuser(true);
-  fuser(graph.get());
+  std::vector<std::string> act_types{};
+
+#ifdef LITE_WITH_X86
+  act_types.push_back("relu");
+#endif
+
+#ifdef LITE_WITH_FPGA
+  act_types.push_back("relu");
+  act_types.push_back("sigmoid");
 #endif
 
-  fusion::FcFuser fuser2(false);
-  fuser2(graph.get());
+  act_types.push_back("");
+  for (size_t i = 0; i < act_types.size(); i++) {
+    std::string act_type = act_types[i];
+    fusion::FcFuser fuser(act_type);
+    fuser(graph.get());
+  }
 }
 
 }  // namespace mir
diff --git a/lite/core/mir/fusion/fc_fuser.cc b/lite/core/mir/fusion/fc_fuser.cc
index 3c99131083d37ea2c8511ed136bff17c891529af..34bcf43d18967d65564f7ce64b6b821057313fc7 100644
--- a/lite/core/mir/fusion/fc_fuser.cc
+++ b/lite/core/mir/fusion/fc_fuser.cc
@@ -41,15 +41,17 @@ void FcFuser::BuildPattern() {
   mul->AsIntermediate();
   add->AsIntermediate();
 
-  if (with_relu_) {
+  if (act_type_ != "") {
+    VLOG(4) << "act_type_: " << act_type_;
     auto* add_out = VarNode("add_out");
-    auto* relu = OpNode("relu", "relu");
-    std::vector<PMNode*> relu_inputs{add_out};
+    auto* activation = OpNode(act_type_, act_type_);
+    std::vector<PMNode*> act_inputs{add_out};
     add_inputs >> *add >> *add_out;
-    relu_inputs >> *relu >> *Out;
+    act_inputs >> *activation >> *Out;
     add_out->AsIntermediate();
-    relu->AsIntermediate();
+    activation->AsIntermediate();
   } else {
+    VLOG(4) << "act_type_: empty";
     add_inputs >> *add >> *Out;
   }
 }
@@ -82,8 +84,8 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
   op_desc.SetAttr(
       "in_num_col_dims",
       matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
-  if (with_relu_) {
-    op_desc.SetAttr("activation_type", std::string{"relu"});
+  if (act_type_ != "") {
+    op_desc.SetAttr("activation_type", act_type_);
   }
   return op_desc;
 }
diff --git a/lite/core/mir/fusion/fc_fuser.h b/lite/core/mir/fusion/fc_fuser.h
index 6cb08f41574b67df1c78fa296d2d395771a66ee1..f38b9f6dbe0be0f97123e48282c111a717f88ca5 100644
--- a/lite/core/mir/fusion/fc_fuser.h
+++ b/lite/core/mir/fusion/fc_fuser.h
@@ -25,13 +25,13 @@ namespace fusion {
 
 class FcFuser : public FuseBase {
  public:
-  explicit FcFuser(bool with_relu) : with_relu_(with_relu) {}
+  explicit FcFuser(std::string act_type) : act_type_(act_type) {}
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
  private:
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
-  bool with_relu_;
+  std::string act_type_ = "";
 };
 
 }  // namespace fusion
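Taken together, the pass and fuser changes replace the boolean `with_relu` switch with a string-keyed loop: the pass collects one candidate activation per target, always appends `""` for the activation-free pattern, and runs a fresh `FcFuser` per entry. Below is a minimal standalone sketch of that control flow, not Paddle-Lite code; `FcFuserSketch` and its `Run()` are hypothetical stand-ins for `fusion::FcFuser` and its call operator on the graph:

```cpp
#include <iostream>
#include <string>
#include <vector>

#define LITE_WITH_FPGA  // pretend we are building the FPGA target

// Hypothetical stand-in for fusion::FcFuser; Run() plays the role of
// operator()(SSAGraph*).
struct FcFuserSketch {
  explicit FcFuserSketch(std::string act_type) : act_type_(act_type) {}
  void Run() {
    std::cout << "fusing mul+add"
              << (act_type_.empty() ? std::string() : "+" + act_type_) << "\n";
  }
  std::string act_type_;
};

int main() {
  std::vector<std::string> act_types;
#ifdef LITE_WITH_X86
  act_types.push_back("relu");
#endif
#ifdef LITE_WITH_FPGA
  act_types.push_back("relu");
  act_types.push_back("sigmoid");
#endif
  act_types.push_back("");  // the plain mul+add pattern runs last
  for (const std::string& act_type : act_types) {
    FcFuserSketch fuser(act_type);
    fuser.Run();
  }
  return 0;
}
```

Running the `""` fuser last is deliberate: each successful match consumes its nodes via `AsIntermediate()`, so the bare mul+add pattern only fuses whatever the activation-specific passes left behind.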
diff --git a/lite/kernels/fpga/activation_compute.h b/lite/kernels/fpga/activation_compute.h
index 796d54413f0fd7d434bc660c62b536b1a1eedd4b..5cc431e2d41e2de1d841ee386de3aae4434e3865 100644
--- a/lite/kernels/fpga/activation_compute.h
+++ b/lite/kernels/fpga/activation_compute.h
@@ -14,6 +14,8 @@
 #pragma once
 
 #include <algorithm>
+#include <map>
+#include <string>
 #include "lite/backends/fpga/KD/float16.hpp"
 #include "lite/backends/fpga/KD/pes/relu_pe.hpp"
 #include "lite/core/kernel.h"
@@ -24,6 +26,13 @@ namespace lite {
 namespace kernels {
 namespace fpga {
 
+static std::map<std::string, zynqmp::ActiveType> activation_map = {
+    {"relu", zynqmp::TYPE_RELU},
+    {"relu6", zynqmp::TYPE_RELU6},
+    {"leaky_relu", zynqmp::TYPE_LEAKY_RELU},
+    {"sigmoid", zynqmp::TYPE_SIGMOID},
+    {"", zynqmp::TYPE_NONE}};
+
 class ReluCompute
     : public KernelLite<TARGET(kFPGA), PRECISION(kFP16), DATALAYOUT(kNHWC)> {
  public:
diff --git a/lite/kernels/fpga/elementwise_compute.cc b/lite/kernels/fpga/elementwise_compute.cc
index 39780d82276188b141e31d89466fbe09434393aa..0c9df759498b2b4224729890b4dbd458da36f40c 100755
--- a/lite/kernels/fpga/elementwise_compute.cc
+++ b/lite/kernels/fpga/elementwise_compute.cc
@@ -16,6 +16,7 @@
 #include <string>
 #include "lite/backends/arm/math/funcs.h"
 #include "lite/backends/fpga/KD/debugger.hpp"
+#include "lite/kernels/fpga/activation_compute.h"
 
 namespace paddle {
 namespace lite {
@@ -29,11 +30,9 @@ void ElementwiseAddCompute::PrepareForRun() {
   auto& param = Param<operators::ElementwiseParam>();
   param.Out->mutable_data<float16>();
-
   ew_param.inputs = {param.X->ZynqTensor(), param.Y->ZynqTensor()};
   ew_param.output = param.Out->ZynqTensor();
   ew_param.axis = param.axis;
-  ew_param.activeParam.type = zynqmp::TYPE_NONE;
 
   pe_.init();
@@ -50,14 +49,17 @@ void ElementwiseAddActivationCompute::PrepareForRun() {
   zynqmp::ElementwiseAddParam& ew_param = pe_.param();
   auto& param = Param<operators::FusionElementwiseActivationParam>();
-  if (param.act_type != "relu") {
+
+  if (activation_map.count(param.act_type)) {
+    ew_param.activeParam.type = activation_map[param.act_type];
+  } else {
     LOG(FATAL) << "unsupported Activation type: " << param.act_type;
   }
+
   param.Out->mutable_data<float16>();
   ew_param.inputs = {param.X->ZynqTensor(), param.Y->ZynqTensor()};
   ew_param.output = param.Out->ZynqTensor();
   ew_param.axis = param.axis;
-  ew_param.activeParam.type = zynqmp::TYPE_RELU;
   pe_.init();
   pe_.apply();
 }
@@ -76,7 +78,6 @@ void ElementwiseMulCompute::PrepareForRun() {
 
   scale_param.input = param.X->ZynqTensor();
   scale_param.output = param.Out->ZynqTensor();
-  scale_param.activeParam.type = zynqmp::TYPE_NONE;
 
   int channel = scale_param.input->shape().channel();
@@ -103,9 +104,10 @@ void ElementwiseMulCompute::Run() {
   auto& param = Param<operators::ElementwiseParam>();
-  param.Y->ZynqTensor()->flush();
-  scale_.copyFrom(param.Y->ZynqTensor());
-  scale_.invalidate();
+  if (!param.Y->persistable()) {
+    scale_.copyFrom(param.Y->ZynqTensor());
+    scale_.invalidate();
+  }
   pe_.dispatch();
 #ifdef FPGA_PRINT_TENSOR
   zynqmp::ScaleParam& scale_param = pe_.param();
@@ -170,7 +172,10 @@ REGISTER_LITE_KERNEL(elementwise_mul,
                      {LiteType::GetTensorTy(TARGET(kFPGA),
                                             PRECISION(kFP16),
                                             DATALAYOUT(kNHWC))})
-    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(TARGET(kFPGA),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kNHWC))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kFPGA),
                                        PRECISION(kFP16),
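Both elementwise kernels now resolve the `act_type` string through the shared `activation_map`, with `count()` separating supported strings from ones that must fail loudly. A standalone sketch of the lookup pattern, assuming a stand-in `ActiveTypeSketch` enum in place of the real `zynqmp` values:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for zynqmp's activation enum; the real values live in the KD headers.
enum ActiveTypeSketch { TYPE_NONE, TYPE_RELU, TYPE_RELU6, TYPE_LEAKY_RELU, TYPE_SIGMOID };

static std::map<std::string, ActiveTypeSketch> activation_map = {
    {"relu", TYPE_RELU},
    {"relu6", TYPE_RELU6},
    {"leaky_relu", TYPE_LEAKY_RELU},
    {"sigmoid", TYPE_SIGMOID},
    {"", TYPE_NONE}};  // "" = op carries no fused activation

int main() {
  const std::vector<std::string> attrs = {"relu", "sigmoid", "tanh", ""};
  for (const std::string& s : attrs) {
    if (activation_map.count(s)) {
      std::cout << "'" << s << "' -> enum value " << activation_map[s] << "\n";
    } else {
      // elementwise_compute.cc takes this branch and calls LOG(FATAL)
      std::cout << "'" << s << "' -> unsupported\n";
    }
  }
  return 0;
}
```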
#include "lite/kernels/fpga/fc_compute.h" +#include "lite/kernels/fpga/activation_compute.h" + #include "lite/backends/fpga/KD/debugger.hpp" #include "lite/core/op_registry.h" #include "lite/core/type_system.h" @@ -36,6 +38,10 @@ void FcCompute::PrepareForRun() { fc_param.filter = param.w->ZynqTensor(); fc_param.bias = param.bias->ZynqTensor(); + if (activation_map.count(param.activation_type)) { + fc_param.activeParam.type = activation_map[param.activation_type]; + } + pe_.init(); pe_.apply(); }