提交 6f271022 编写于 作者: Z zhangyang

add MUL kernel for FPGA track

上级 ad57b923
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MUL_OP
#include "operators/kernel/mul_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool MulKernel<FPGA, float>::Init(MulParam<FPGA> *param) {
bool relu_enabled = false;
auto input_x = const_cast<LoDTensor *>(param->InputX());
auto filter = const_cast<LoDTensor *>(param->InputY());
auto out = param->Out();
PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == filter->dims()[0],
"Image channel should be equal to weight number");
int channel = (uint32_t)out->dims()[1];
auto bs_ptr =
(float *)fpga::fpga_malloc(2 * channel * sizeof(float)); // NOLINT
for (int i = 0; i < channel; i++) {
bs_ptr[i + channel] = 1;
bs_ptr[i] = 0;
}
int num = (uint32_t)filter->dims()[1];
int chw = (uint32_t)filter->dims()[0];
PADDLE_MOBILE_ENFORCE(
chw == input_x->numel(),
"Filter element num should be equal to IFM element num");
int height = (uint32_t)input_x->dims()[2];
int width = (uint32_t)input_x->dims()[3];
int filter_channel = chw / height / width;
filter->Resize(framework::make_ddim({num, filter_channel, height, width}));
float max_value = fpga::filter_find_max(filter);
fpga::format_fc_filter(filter, max_value);
int element_num_per_div = fpga::get_filter_num_per_div(filter, 1);
fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel);
fpga::format_fp16_ofm(out);
fpga::WrapperConvArgs conv_arg = {0};
fpga::fill_conv_arg(&conv_arg, input_x, out, filter, relu_enabled, 1, 1, 1, 0,
0, bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
template <>
void MulKernel<FPGA, float>::Compute(const MulParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -61,5 +61,7 @@ REGISTER_OPERATOR_CPU(mul, ops::MulOp); ...@@ -61,5 +61,7 @@ REGISTER_OPERATOR_CPU(mul, ops::MulOp);
#ifdef PADDLE_MOBILE_MALI_GPU #ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(mul, ops::MulOp); REGISTER_OPERATOR_MALI_GPU(mul, ops::MulOp);
#endif #endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(mul, ops::MulOp);
#endif
#endif #endif
...@@ -438,6 +438,15 @@ class MulParam : OpParam { ...@@ -438,6 +438,15 @@ class MulParam : OpParam {
GType *out_; GType *out_;
int x_num_col_dims_; int x_num_col_dims_;
int y_num_col_dims_; int y_num_col_dims_;
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::WrapperConvArgs fpga_conv_args;
public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; }
#endif
}; };
#endif #endif
......
...@@ -18,8 +18,9 @@ static const char *g_resnet_combine = "../models/resnet50"; ...@@ -18,8 +18,9 @@ static const char *g_resnet_combine = "../models/resnet50";
int main() { int main() {
DLOG << paddle_mobile::fpga::open_device(); DLOG << paddle_mobile::fpga::open_device();
paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile; paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
if (paddle_mobile.Load(std::string(g_resnet_combine) + "/model", // if (paddle_mobile.Load(std::string(g_resnet_combine) + "/model",
std::string(g_resnet_combine) + "/params", true)) { // std::string(g_resnet_combine) + "/params", true)) {
if (paddle_mobile.Load(std::string(g_resnet_combine), true)) {
std::vector<int64_t> dims{1, 3, 224, 224}; std::vector<int64_t> dims{1, 3, 224, 224};
Tensor input_tensor; Tensor input_tensor;
SetupTensor<float>(&input_tensor, {1, 3, 224, 224}, static_cast<float>(0), SetupTensor<float>(&input_tensor, {1, 3, 224, 224}, static_cast<float>(0),
......
...@@ -121,6 +121,7 @@ if (CON GREATER -1) ...@@ -121,6 +121,7 @@ if (CON GREATER -1)
set(FUSION_CONVBNRELU_OP ON) set(FUSION_CONVBNRELU_OP ON)
set(FUSION_CONVBN_OP ON) set(FUSION_CONVBN_OP ON)
set(FUSION_CONVADD_OP ON) set(FUSION_CONVADD_OP ON)
set(MUL_OP ON)
set(FOUND_MATCH ON) set(FOUND_MATCH ON)
endif() endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册