From ab7155f86997442e930f62882e92a113b1ea3df9 Mon Sep 17 00:00:00 2001 From: zhangyang Date: Tue, 7 Aug 2018 22:56:01 +0800 Subject: [PATCH] add PerformBypass inside FeedOp --- src/operators/feed_op.h | 35 ++++++++++++++++--- .../kernel/fpga/conv_add_bn_kernel.cpp | 4 +-- .../kernel/fpga/conv_add_bn_relu_kernel.cpp | 4 +-- .../kernel/fpga/conv_add_relu_kernel.cpp | 4 +-- src/operators/op_param.h | 10 ++++++ 5 files changed, 47 insertions(+), 10 deletions(-) diff --git a/src/operators/feed_op.h b/src/operators/feed_op.h index e45ad38fd6..8056e78aed 100644 --- a/src/operators/feed_op.h +++ b/src/operators/feed_op.h @@ -30,18 +30,45 @@ class FeedOp : public framework::OperatorBase { : framework::OperatorBase(type, inputs, outputs, attrs, scope), param_(inputs, outputs, attrs, scope.get()) {} - void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } - void Init() {} + protected: + FeedParam param_; + +#ifdef PADDLE_MOBILE_FPGA + void RunImpl() const { fpga::PerformBypass(param_.FpgaArgs()); } + + void Init() { + const Tensor *input = param_.InputX(); + auto input_ptr = input->data(); + Tensor *output = param_.Out(); + auto output_ptr = output->mutable_data(); + fpga::BypassArgs args; + args.convert_type = fpga::DATA_FP32_TO_FP16; + args.layout_type = fpga::LAYOUT_CHW_TO_HWC; + args.image.address = (void *)input_ptr; + args.image.channels = input->dims()[1]; + args.image.height = input->dims()[2]; + args.image.width = input->dims()[3]; + args.output.address = output_ptr; + param_.SetFpgaArgs(args); + } void InferShape() const { auto out_dims = param_.Out()->dims(); out_dims[0] = param_.BatchSize(); param_.Out()->Resize(out_dims); + param_.Out()->ShareDataWith(*param_.InputX()); // TODO How to handle fp16 } +#else + void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } - protected: - FeedParam param_; + void Init() {} + + void InferShape() const { + auto out_dims = param_.Out()->dims(); + out_dims[0] = param_.BatchSize(); + 
param_.Out()->Resize(out_dims); +  } +#endif }; } // namespace operators diff --git a/src/operators/kernel/fpga/conv_add_bn_kernel.cpp b/src/operators/kernel/fpga/conv_add_bn_kernel.cpp index 6f9da6bc1d..6719db3a80 100644 --- a/src/operators/kernel/fpga/conv_add_bn_kernel.cpp +++ b/src/operators/kernel/fpga/conv_add_bn_kernel.cpp @@ -24,13 +24,13 @@ template <> bool ConvAddBNKernel::Init(FusionConvAddBNParam *param) { bool relu_enabled = false; const Tensor *input = param->Input(); - auto input_ptr = input->data(); + auto input_ptr = input->data(); const Tensor *bias = param->Bias(); auto bias_ptr = bias->data(); const Tensor *filter = param->Filter(); auto filter_ptr = filter->data(); Tensor *out = param->Output(); - auto out_ptr = out->mutable_data(); + auto out_ptr = out->mutable_data(); auto bn_mean_ptr = param->InputMean()->data(); auto bn_var_ptr = param->InputVariance()->data(); auto bn_scale_ptr = param->InputScale()->data(); diff --git a/src/operators/kernel/fpga/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/fpga/conv_add_bn_relu_kernel.cpp index 66a593df84..2f80ec9742 100644 --- a/src/operators/kernel/fpga/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/fpga/conv_add_bn_relu_kernel.cpp @@ -24,13 +24,13 @@ template <> bool ConvAddBNReluKernel::Init(FusionConvAddBNReluParam *param) { bool relu_enabled = true; const Tensor *input = param->Input(); - auto input_ptr = input->data(); + auto input_ptr = input->data(); const Tensor *bias = param->Bias(); auto bias_ptr = bias->data(); const Tensor *filter = param->Filter(); auto filter_ptr = filter->data(); Tensor *out = param->Output(); - auto out_ptr = out->mutable_data(); + auto out_ptr = out->mutable_data(); auto bn_mean_ptr = param->InputMean()->data(); auto bn_var_ptr = param->InputVariance()->data(); auto bn_scale_ptr = param->InputScale()->data(); diff --git a/src/operators/kernel/fpga/conv_add_relu_kernel.cpp b/src/operators/kernel/fpga/conv_add_relu_kernel.cpp index 9692bcef87..a20f4e4837 100644
--- a/src/operators/kernel/fpga/conv_add_relu_kernel.cpp +++ b/src/operators/kernel/fpga/conv_add_relu_kernel.cpp @@ -24,13 +24,13 @@ template <> bool ConvAddReluKernel::Init(FusionConvAddReluParam *param) { bool relu_enabled = true; const Tensor *input = param->Input(); - auto input_ptr = input->data(); + auto input_ptr = input->data(); const Tensor *bias = param->Bias(); auto bias_ptr = bias->data(); const Tensor *filter = param->Filter(); auto filter_ptr = filter->data(); Tensor *out = param->Output(); - auto out_ptr = out->mutable_data(); + auto out_ptr = out->mutable_data(); PADDLE_MOBILE_ENFORCE(input->dims()[1] == bias->dims()[0], "Image channel should be equal to bias number"); diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 0821ab8c32..844d2f1068 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -665,6 +665,16 @@ class FeedParam : public OpParam { Tensor *input_x_; Tensor *out_; int batch_size; + +#ifdef PADDLE_MOBILE_FPGA + + private: + fpga::BypassArgs fpga_bypass_args; + + public: + const fpga::BypassArgs &FpgaArgs() const { return fpga_bypass_args; } + void SetFpgaArgs(const fpga::BypassArgs &args) { fpga_bypass_args = args; } +#endif }; class FetchParam : public OpParam { -- GitLab