未验证 提交 2cda8b9b 编写于 作者: qnqinan's avatar qnqinan 提交者: GitHub

Merge pull request #1485 from qnqinan/develop

add ops of marker2 in FPGA track fixed#1484
......@@ -58,6 +58,7 @@ REGISTER_OPERATOR_CPU(relu6, ops::Relu6Op);
REGISTER_OPERATOR_MALI_GPU(relu, ops::ReluOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(relu, ops::ReluOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(relu, ops::ReluOp);
......
......@@ -36,6 +36,7 @@ REGISTER_OPERATOR_CPU(elementwise_mul, ops::ElementwiseMulOp);
REGISTER_OPERATOR_MALI_GPU(elementwise_mul, ops::ElementwiseMulOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(elementwise_mul, ops::ElementwiseMulOp);
#endif
#endif
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_FC_RELU_OP
#ifdef FUSION_FCRELU_OP
#include "operators/fusion_fc_relu_op.h"
namespace paddle_mobile {
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#pragma once
#include "framework/operator.h"
#include "operators/math/elementwise_op_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......
......@@ -15,18 +15,21 @@ limitations under the License. */
#include "operators/kernel/elementwise_add_kernel.h"
#include <string>
#include "fpga/V1/api.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ElementwiseAddKernel<FPGA, float>::Init(ElementwiseAddParam<FPGA> *param) {
// bool relu_enabled = false;
auto *input_y = const_cast<LoDTensor *>(param->InputY());
auto *out = param->Out();
if (input_y->type() != typeid(float)) {
paddle_mobile::fpga::ActivationType activation_enable =
paddle_mobile::fpga::NONE;
int16_t leaky_relu_negative_slope = 0;
auto *input_x = const_cast<LoDTensor *>(param->InputX());
auto *input_y = const_cast<LoDTensor *>(param->InputY());
auto *out = param->Out();
auto input_x_ptr = input_x->data<half>();
auto input_y_ptr = input_y->data<half>();
fpga::format_fp16_ofm(out);
......@@ -57,13 +60,131 @@ bool ElementwiseAddKernel<FPGA, float>::Init(ElementwiseAddParam<FPGA> *param) {
ewaddArgs.output.address = out_ptr;
fpga::expand_EW_arg(&ewaddArgs);
param->SetFpgaArgs(ewaddArgs);
} else {
param->float_input_x.Resize(param->InputX()->dims());
param->float_input_x.init(typeid(float));
fpga::format_fp32_ofm(&(param->float_input_x));
param->float_out.Resize(param->InputX()->dims());
// param->float_out.init(typeid(float));
param->float_out.mutable_data<float>(param->InputX()->dims());
fpga::format_fp32_ofm(&(param->float_out));
fpga::format_fp16_ofm(out);
}
return true;
}
inline void ElementwiseAddCompute(const ElementwiseAddParam<FPGA> &param) {
auto input_x = param.float_input_x;
auto input_y = param.InputY();
auto Out = param.float_out;
int axis = param.Axis();
const auto &x_dims = input_x.dims();
const auto &y_dims = input_y->dims();
/// axis = -1 represent the last dimensions.
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
size_t batch = 1;
size_t channels = 1;
size_t elementwise_num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
elementwise_num *= x_dims[i];
}
const float *bias_data = input_y->data<float>();
const float *input_data = input_x.data<float>();
float *output_data = Out.mutable_data<float>();
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
size_t offset = (i * channels + j) * elementwise_num;
const float *input = input_data + offset;
const float bias = bias_data[j];
float *output = output_data + offset;
// DLOG << "output address: "<< output;
for (int k = 0; k < elementwise_num; ++k) {
output[k] = input[k] + bias;
// DLOG << "output[" << k << "]= " << output[k] ;
}
}
}
}
template <>
void ElementwiseAddKernel<FPGA, float>::Compute(
const ElementwiseAddParam<FPGA> &param) {
auto input_y = const_cast<LoDTensor *>(param.InputY());
if (input_y->type() != typeid(float)) {
fpga::ComputeFpgaEWAdd(param.FpgaArgs());
} else {
auto input_x = const_cast<LoDTensor *>(param.InputX());
auto intput_x_float = const_cast<Tensor *>(&(param.float_input_x));
fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
args.input_data_type = fpga::DATA_TYPE_FP16;
args.output_data_type = fpga::DATA_TYPE_FP32;
args.input_layout_type = fpga::LAYOUT_CHW;
args.output_layout_type = fpga::LAYOUT_HWC;
args.image.address = input_x->data<half>();
args.image.channels = (uint32_t)(input_x->fpga_data_num);
args.image.height = 1;
args.image.width = 1;
args.image.pad_height = 0;
args.image.pad_width = 0;
args.output.address = intput_x_float->data<float>();
args.output.scale_address = intput_x_float->scale;
// fpga::fpga_flush(input_x->data<half>(),input_x->fpga_data_num *
// sizeof(half));
fpga::PerformBypass(args);
fpga::fpga_invalidate(args.output.address,
input_x->fpga_data_num * sizeof(float));
// just for test
/* {
static int cnt = 0;
if(cnt == 0){
std::string str= "first_bypass_data";
float rslt = 0.0f;
fpga::savefile(str, args.output.address, input_x->fpga_data_num,
rslt); cnt++;
}
}*/
ElementwiseAddCompute(param);
auto out_float = const_cast<Tensor *>(&(param.float_out));
DLOG << "out float: " << out_float->data<float>();
fpga::fpga_flush(out_float->data<float>(),
input_x->fpga_data_num * sizeof(float));
// just for test
/*{
static int cnt = 0;
if(cnt == 0){
std::string str= "ew_output_data";
float rslt = 0.0f;
fpga::savefile(str, out_float->data<float>(), input_x->fpga_data_num,
rslt); cnt++;
}
}*/
auto Out = param.Out();
args.input_data_type = fpga::DATA_TYPE_FP32;
args.output_data_type = fpga::DATA_TYPE_FP16;
args.input_layout_type = fpga::LAYOUT_CHW;
args.output_layout_type = fpga::LAYOUT_HWC;
args.image.address = out_float->data<float>();
args.image.channels = (uint32_t)(input_x->fpga_data_num);
args.image.height = 1;
args.image.width = 1;
args.image.pad_height = 0;
args.image.pad_width = 0;
args.output.address = Out->data<half>();
args.output.scale_address = Out->scale;
fpga::PerformBypass(args);
}
}
} // namespace operators
} // namespace paddle_mobile
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISEMUL_OP
#include "operators/kernel/elementwise_mul_kernel.h"
#include "operators/math/elementwise_op_function.h"
namespace paddle_mobile {
namespace operators {
template <typename T>
struct MulFunctor {
inline T operator()(T a, T b) const { return a * b; }
};
template <>
bool ElementwiseMulKernel<FPGA, float>::Init(ElementwiseMulParam<FPGA> *param) {
param->float_input_x.Resize(param->InputX()->dims());
param->float_input_x.init(typeid(float));
fpga::format_fp32_ofm(&(param->float_input_x));
param->float_out.Resize(param->InputX()->dims());
param->float_out.init(typeid(float));
fpga::format_fp32_ofm(&(param->float_out));
auto *out = param->Out();
fpga::format_fp16_ofm(out);
return true;
}
template <>
void ElementwiseMulKernel<FPGA, float>::Compute(
const ElementwiseMulParam<FPGA> &param) {
auto input_x = const_cast<LoDTensor *>(param.InputX());
auto intput_x_float = const_cast<Tensor *>(&(param.float_input_x));
// auto intput_x_32_ptr =
// const_cast<float*>(param.float_input_x.data<float>());
fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
args.input_data_type = fpga::DATA_TYPE_FP16;
args.output_data_type = fpga::DATA_TYPE_FP32;
args.input_layout_type = fpga::LAYOUT_CHW;
args.output_layout_type = fpga::LAYOUT_HWC;
args.image.address = input_x->data<half>();
args.image.channels = (uint32_t)(input_x->fpga_data_num);
args.image.height = 1;
args.image.width = 1;
args.image.pad_height = 0;
args.image.pad_width = 0;
args.output.address = intput_x_float->data<float>();
args.output.scale_address = intput_x_float->scale;
fpga::PerformBypass(args);
fpga::fpga_invalidate(args.output.address,
input_x->fpga_data_num * sizeof(float));
auto input_y = param.InputY();
int axis = param.Axis();
auto out_float = const_cast<Tensor *>(&(param.float_out));
ElementwiseComputeEx<MulFunctor<float>, float>(
intput_x_float, input_y, axis, MulFunctor<float>(), out_float);
fpga::fpga_flush(out_float->data<float>(),
input_x->fpga_data_num * sizeof(float));
Tensor *Out = param.Out();
args.input_data_type = fpga::DATA_TYPE_FP32;
args.output_data_type = fpga::DATA_TYPE_FP16;
args.input_layout_type = fpga::LAYOUT_CHW;
args.output_layout_type = fpga::LAYOUT_HWC;
args.image.address = out_float->data<float>();
args.image.channels = (uint32_t)(Out->fpga_data_num);
args.image.height = 1;
args.image.width = 1;
args.image.pad_height = 0;
args.image.pad_width = 0;
args.output.address = Out->data<half>();
args.output.scale_address = Out->scale;
fpga::PerformBypass(args);
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -62,15 +62,27 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
output->ShareDataWith(*input);
return;
}
fpga::BypassArgs args = param.fpga_bypass_args;
auto input_address = (input->data<half>());
args.image.address = static_cast<void *>(input_address);
float *outdata_ptr =
reinterpret_cast<float *>(param.fpga_bypass_args.output.address);
const int num_th = 32;
if ((param.Out()->fpga_data_num) < num_th) {
fpga::fpga_invalidate(input_address, (input->fpga_data_num) * sizeof(half));
for (int idx = 0; idx < product(input->dims()); ++idx) {
outdata_ptr[idx] = fpga::fp16_2_fp32(input_address[idx]);
}
return;
}
fpga::PerformBypass(args);
auto outC = param.Out()->dims()[1];
auto outH = param.Out()->dims()[2];
auto outW = param.Out()->dims()[3];
float *outdata_ptr =
reinterpret_cast<float *>(param.fpga_bypass_args.output.address);
fpga::fpga_invalidate(param.fpga_bypass_args.output.address,
param.Out()->fpga_data_num * sizeof(float));
......
......@@ -30,8 +30,8 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam<FPGA> *param) {
auto input_z_ptr = input_z->data<float>();
auto out = param->Out();
PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == filter->dims()[0],
"Image channel should be equal to weight number");
// PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == filter->dims()[0],
// "Image channel should be equal to weight number");
int channel = (uint32_t)out->dims()[1];
auto bs_ptr =
(float *)fpga::fpga_malloc(2 * channel * sizeof(float)); // NOLINT
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_FCRELU_OP
#include "operators/kernel/fc_relu_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam<FPGA> *param) {
// bool relu_enabled = false;
paddle_mobile::fpga::ActivationType activation_enable =
paddle_mobile::fpga::LEAKYRELU;
int16_t leaky_relu_negative_slope = 0;
auto input_x = const_cast<LoDTensor *>(param->InputX());
auto filter = const_cast<Tensor *>(param->InputY());
const Tensor *input_z = param->InputZ();
auto input_z_ptr = input_z->data<float>();
auto out = param->Out();
// PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == filter->dims()[0],
// "Image channel should be equal to weight number");
int channel = (uint32_t)out->dims()[1];
auto bs_ptr =
(float *)fpga::fpga_malloc(2 * channel * sizeof(float)); // NOLINT
for (int i = 0; i < channel; i++) {
bs_ptr[i + channel] = 1;
bs_ptr[i] = input_z_ptr[i];
}
int num = (uint32_t)filter->dims()[1];
int chw = (uint32_t)filter->dims()[0];
PADDLE_MOBILE_ENFORCE(
chw == input_x->numel(),
"Filter element num should be equal to IFM element num");
int height = (uint32_t)input_x->dims()[2];
int width = (uint32_t)input_x->dims()[3];
int filter_channel = chw / height / width;
out->Resize(framework::make_ddim({1, channel, 1, 1}));
filter->Resize(framework::make_ddim({num, filter_channel, height, width}));
float max_value = fpga::filter_find_max(filter);
fpga::format_fc_filter(filter, max_value);
int element_num_per_div = fpga::get_filter_num_per_div(filter, 1);
fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel);
fpga::format_fp16_ofm(out);
fpga::SplitConvArgs conv_arg = {0};
fpga::fill_split_arg(&conv_arg, input_x, out, filter, activation_enable,
leaky_relu_negative_slope, 1, 1, 1, 0, 0, bs_ptr);
param->SetFpgaArgs(conv_arg);
return true;
}
template <>
void FusionFcReluKernel<FPGA, float>::Compute(
const FusionFcReluParam<FPGA> &param) {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -503,8 +503,10 @@ void ProposalKernel<FPGA, float>::Compute(const ProposalParam<FPGA> &param) {
auto score_index = *(param.score_index_.get());
int pre_nms_top_n = param.pre_nms_topn_;
int post_nms_top_n = 100; // param.post_nms_topn_;
float nms_thresh = param.nms_thresh_;
int post_nms_top_n = param.post_nms_topn_;
// DLOG << " param.post_nms_topn_ : " << param.post_nms_topn_;
float nms_thresh = param.nms_thresh_ / 2.0f;
float min_size = param.min_size_;
float eta = param.eta_;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RELU_OP
#include "operators/kernel/activation_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ReluKernel<FPGA, float>::Init(ReluParam<FPGA> *param) {
param->Out()->ShareDataWith(*param->InputX());
return true;
}
template <>
void ReluKernel<FPGA, float>::Compute(const ReluParam<FPGA> &param) {}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "op_param.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
......@@ -47,6 +47,9 @@ template class ConvParam<GPU_MALI>;
template class ElementwiseAddParam<CPU>;
template class ElementwiseAddParam<FPGA>;
template class ElementwiseAddParam<GPU_MALI>;
template class ElementwiseMulParam<CPU>;
template class ElementwiseMulParam<FPGA>;
template class ElementwiseMulParam<GPU_MALI>;
#ifdef MUL_OP
template class MulParam<CPU>;
......
......@@ -563,6 +563,10 @@ class ElementwiseAddParam : public OpParam {
public:
const fpga::EWAddArgs &FpgaArgs() const { return fpga_EW_add_args; }
void SetFpgaArgs(const fpga::EWAddArgs &args) { fpga_EW_add_args = args; }
public:
Tensor float_input_x, float_out;
#endif
};
......@@ -596,6 +600,12 @@ class ElementwiseMulParam : public OpParam {
GType *input_y_;
GType *out_;
int axis_;
#ifdef PADDLE_MOBILE_FPGA
public:
Tensor float_input_x, float_out;
#endif
};
#endif
......
......@@ -83,6 +83,8 @@ if (CON GREATER -1)
ADD_EXECUTABLE(test-rfcn-api fpga/test_rfcn_api.cpp)
target_link_libraries(test-rfcn-api paddle-mobile)
ADD_EXECUTABLE(test-marker2 fpga/test_marker2.cpp test_helper.h test_include.h executor_for_test.h )
target_link_libraries(test-marker2 paddle-mobile)
set(FOUND_MATCH ON)
endif ()
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
#ifdef PADDLE_MOBILE_FPGA_V1
#include "fpga/V1/api.h"
#endif
#ifdef PADDLE_MOBILE_FPGA_V2
#include "fpga/V2/api.h"
#endif
#include <string>
#ifdef COST_TIME_PRINT
#include <sys/time.h>
#include <time.h>
#include <iomanip>
#endif
void readStream(std::string filename, char *buf) {
std::ifstream in;
in.open(filename, std::ios::in | std::ios::binary);
if (!in.is_open()) {
std::cout << "open File Failed." << std::endl;
return;
}
in.seekg(0, std::ios::end); // go to the end
auto length = in.tellg(); // report location (this is the length)
in.seekg(0, std::ios::beg); // go back to the beginning
in.read(buf, length);
DLOG << length;
in.close();
}
void convert_to_chw(int16_t **data_in, int channel, int height, int width,
int num, int16_t *data_tmp) {
int64_t amount_per_side = width * height;
for (int n = 0; n < num; n++) {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
for (int c = 0; c < channel; c++) {
*(data_tmp + n * amount_per_side * channel + c * amount_per_side +
width * h + w) = *((*data_in)++);
}
}
}
}
}
void dump_stride_half(std::string filename, Tensor input_tensor,
const int dumpnum, bool use_chw) {
// bool use_chw = true;
if (input_tensor.dims().size() != 4) return;
int c = (input_tensor.dims())[1];
int h = (input_tensor.dims())[2];
int w = (input_tensor.dims())[3];
int n = (input_tensor.dims())[0];
auto data_ptr = input_tensor.get_data();
auto *data_ptr_16 = reinterpret_cast<half *>(data_ptr);
auto data_tmp = data_ptr_16;
if (use_chw) {
data_tmp =
reinterpret_cast<half *>(malloc(n * c * h * w * sizeof(int16_t)));
convert_to_chw(&data_ptr_16, c, h, w, n, data_tmp);
}
std::ofstream out(filename.c_str());
float result = 0;
int stride = input_tensor.numel() / dumpnum;
stride = stride > 0 ? stride : 1;
for (int i = 0; i < input_tensor.numel(); i += stride) {
result = paddle_mobile::fpga::fp16_2_fp32(data_tmp[i]);
out << result << std::endl;
}
out.close();
if (data_tmp != data_ptr_16) {
free(data_tmp);
}
}
void dump_stride_float(std::string filename, Tensor input_tensor,
const int dumpnum) {
auto data_ptr = reinterpret_cast<float *>(input_tensor.get_data());
std::ofstream out(filename.c_str());
float result = 0;
int stride = input_tensor.numel() / dumpnum;
stride = stride > 0 ? stride : 1;
for (int i = 0; i < input_tensor.numel(); i += stride) {
result = data_ptr[i];
out << result << std::endl;
}
out.close();
}
void dump_stride(std::string filename, Tensor input_tensor, const int dumpnum,
bool use_chw) {
static int i = 0;
if (input_tensor.numel() == 0) {
return;
}
if (input_tensor.type() == typeid(float)) {
DLOG << "op: " << i++ << ", float data " << input_tensor.numel();
dump_stride_float(filename, input_tensor, dumpnum);
} else {
DLOG << "op: " << i++ << ", half data " << input_tensor.numel();
dump_stride_half(filename, input_tensor, dumpnum, use_chw);
}
DLOG << "dump input address: " << input_tensor.get_data();
}
static const char *g_marker_combine = "../models/marker/marker_2segment";
// static const char *g_marker_combine = "../models/marker/model2";
static const char *g_image_src_float =
"../models/marker/marker_2segment/marker_2.bin";
// static const char *g_image_src_float = "../models/marker/model2/data.bin";
int main() {
paddle_mobile::fpga::open_device();
paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
if (paddle_mobile.Load(std::string(g_marker_combine) + "/model",
std::string(g_marker_combine) + "/params", true, false,
1, true)) {
// if (paddle_mobile.Load(std::string(g_marker_combine), true)) {
float img_info[3] = {432, 1280, 1.0f};
auto img = reinterpret_cast<float *>(
fpga::fpga_malloc(144 * 14 * 14 * sizeof(float)));
readStream(g_image_src_float, reinterpret_cast<char *>(img));
std::vector<void *> v(3, nullptr);
paddle_mobile.FeedData({img});
// paddle_mobile.Predict_To(-1);
#ifdef COST_TIME_PRINT
timeval start11, end11;
long dif_sec, dif_usec; // NOLINT
#endif
#ifdef COST_TIME_PRINT
gettimeofday(&start11, NULL);
#endif
paddle_mobile.Predict_To(-1);
#ifdef COST_TIME_PRINT
gettimeofday(&end11, NULL);
dif_sec = end11.tv_sec - start11.tv_sec;
dif_usec = end11.tv_usec - start11.tv_usec;
std::cout << "total: "
<< " cost time: " << (dif_sec * 1000000 + dif_usec) << " us"
<< std::endl;
#endif
for (int i = 0; i < 8; i++) {
auto tensor_ptr = paddle_mobile.FetchResult(i);
std::string saveName = "marker_" + std::to_string(i);
// if(i != 58)
paddle_mobile::fpga::fpga_invalidate((*tensor_ptr).get_data(),
tensor_ptr->numel() * sizeof(float));
// tensor_ptr->numel() * sizeof(float));
dump_stride(saveName, (*tensor_ptr), tensor_ptr->numel(),
true); // 20);//tensor_ptr->numel());
}
// paddle_mobile.GetResults(&v);
DLOG << "Computation done";
fpga::fpga_free(img);
}
return 0;
}
......@@ -138,6 +138,9 @@ if (CON GREATER -1)
set(CONV_TRANSPOSE_OP ON)
set(FUSION_DECONVADDBNRELU_OP ON)
set(FUSION_DECONVADDBN_OP ON)
set(ELEMENTWISEMUL_OP ON)
set(FUSION_FCRELU_OP ON)
set(RELU_OP ON)
set(FOUND_MATCH ON)
endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册