提交 6f8e72cc 编写于 作者: Z Zhen Wang

Merge branch 'develop' of https://github.com/PaddlePaddle/paddle-mobile into optimize-gemm_int8

...@@ -2,9 +2,9 @@ cmake_minimum_required(VERSION 3.0) ...@@ -2,9 +2,9 @@ cmake_minimum_required(VERSION 3.0)
project(paddle-mobile) project(paddle-mobile)
# select the platform to build # select the platform to build
option(CPU "armv7 with neon support" ON) option(CPU "armv7 with neon support" OFF)
option(MALI_GPU "mali gpu support" OFF) option(MALI_GPU "mali gpu support" OFF)
option(FPGA "fpga support" OFF) option(FPGA "fpga support" ON)
option(USE_OPENMP "openmp support" OFF) option(USE_OPENMP "openmp support" OFF)
option(DEBUGING "enable debug mode" ON) option(DEBUGING "enable debug mode" ON)
...@@ -29,7 +29,10 @@ if(DEBUGING) ...@@ -29,7 +29,10 @@ if(DEBUGING)
message(STATUS "debugging mode") message(STATUS "debugging mode")
add_definitions(-DPADDLE_MOBILE_DEBUG) add_definitions(-DPADDLE_MOBILE_DEBUG)
else() else()
add_definitions(-fvisibility=hidden -fvisibility-inlines-hidden) if(FPGA)
else()
add_definitions(-fvisibility=hidden -fvisibility-inlines-hidden)
endif()
endif() endif()
if(USE_EXCEPTION) if(USE_EXCEPTION)
...@@ -93,8 +96,7 @@ else() ...@@ -93,8 +96,7 @@ else()
endif() endif()
if(FPGA) if(FPGA)
set(DEBUGING ON) message("FPGA mode enabled")
add_definitions(-DPADDLE_MOBILE_DEBUG)
add_definitions(-DPADDLE_MOBILE_FPGA) add_definitions(-DPADDLE_MOBILE_FPGA)
else() else()
file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/*.cpp src/operators/kernel/fpga/*.cc) file(GLOB_RECURSE _tmp_list src/operators/kernel/fpga/*.cpp src/operators/kernel/fpga/*.cc)
...@@ -177,6 +179,10 @@ if(DEBUGING) ...@@ -177,6 +179,10 @@ if(DEBUGING)
else() else()
add_subdirectory(test) add_subdirectory(test)
endif() endif()
elseif(FPGA)
add_subdirectory(test)
endif() endif()
...@@ -34,6 +34,7 @@ const char *G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add"; ...@@ -34,6 +34,7 @@ const char *G_OP_TYPE_FUSION_CONV_ADD = "fusion_conv_add";
const char *G_OP_TYPE_LRN = "lrn"; const char *G_OP_TYPE_LRN = "lrn";
const char *G_OP_TYPE_MUL = "mul"; const char *G_OP_TYPE_MUL = "mul";
const char *G_OP_TYPE_MULTICLASS_NMS = "multiclass_nms"; const char *G_OP_TYPE_MULTICLASS_NMS = "multiclass_nms";
const char *G_OP_TYPE_POLYGON_BOX_TRANSFORM = "polygon_box_transform";
const char *G_OP_TYPE_POOL2D = "pool2d"; const char *G_OP_TYPE_POOL2D = "pool2d";
const char *G_OP_TYPE_PRIOR_BOX = "prior_box"; const char *G_OP_TYPE_PRIOR_BOX = "prior_box";
const char *G_OP_TYPE_RELU = "relu"; const char *G_OP_TYPE_RELU = "relu";
...@@ -94,6 +95,7 @@ std::unordered_map< ...@@ -94,6 +95,7 @@ std::unordered_map<
{G_OP_TYPE_FUSION_CONV_BN_ADD_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_BN_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_PRIOR_BOX, {{"Image", "Input"}, {"Boxes", "Variances"}}}, {G_OP_TYPE_PRIOR_BOX, {{"Image", "Input"}, {"Boxes", "Variances"}}},
{G_OP_TYPE_MULTICLASS_NMS, {{"BBoxes", "Scores"}, {"Out"}}}, {G_OP_TYPE_MULTICLASS_NMS, {{"BBoxes", "Scores"}, {"Out"}}},
{G_OP_TYPE_POLYGON_BOX_TRANSFORM, {{"Input"}, {"Output"}}},
{G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}}, {G_OP_TYPE_FC, {{"X", "Y", "Z"}, {"Out"}}},
{G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}}, {G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}},
{G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}}, {G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}},
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cstdlib> #include <cstdlib>
#include <cstring>
#include "common/enforce.h" #include "common/enforce.h"
#include "common/log.h" #include "common/log.h"
......
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
#include "fpga/filter.h" #include "fpga/filter.h"
#include "fpga/image.h" #include "fpga/image.h"
#define FPGA_TEST_MODE #define FPGA_TEST_MODE
//#define PADDLE_MOBILE_OS_LINUX // #define PADDLE_MOBILE_OS_LINUX
namespace paddle_mobile { namespace paddle_mobile {
namespace fpga { namespace fpga {
...@@ -149,7 +149,7 @@ int ComputeBasicConv(const struct ConvArgs &args) { ...@@ -149,7 +149,7 @@ int ComputeBasicConv(const struct ConvArgs &args) {
return do_ioctl(IOCTL_CONFIG_CONV, &args); return do_ioctl(IOCTL_CONFIG_CONV, &args);
} }
int ComputeFpgaConv(const struct WrapperConvArgs &args) { int ComputeFpgaConv(const struct SplitConvArgs &args) {
#ifdef FPGA_TEST_MODE #ifdef FPGA_TEST_MODE
DLOG << "=============ComputeFPGAConv==========="; DLOG << "=============ComputeFPGAConv===========";
DLOG << " filter_num:" << args.filter_num DLOG << " filter_num:" << args.filter_num
...@@ -194,8 +194,8 @@ int ComputeFpgaEWAdd(const struct EWAddArgs &args) { ...@@ -194,8 +194,8 @@ int ComputeFpgaEWAdd(const struct EWAddArgs &args) {
#ifdef FPGA_TEST_MODE #ifdef FPGA_TEST_MODE
DLOG << "=============ComputeFpgaEWAdd==========="; DLOG << "=============ComputeFpgaEWAdd===========";
DLOG << " relu_enabled:" << args.relu_enabled DLOG << " relu_enabled:" << args.relu_enabled
<< " const0:" << fp16_2_fp32(short(args.const0)) << " const0:" << fp16_2_fp32(int16_t(args.const0))
<< " const1:" << fp16_2_fp32(short(args.const1)); << " const1:" << fp16_2_fp32(int16_t(args.const1));
DLOG << " image0_address:" << args.image0.address DLOG << " image0_address:" << args.image0.address
<< " image0_scale_address:" << args.image0.scale_address << " image0_scale_address:" << args.image0.scale_address
<< " image0_channels:" << args.image0.channels << " image0_channels:" << args.image0.channels
...@@ -383,10 +383,10 @@ void format_concat_output(framework::Tensor *out, int height, int width, ...@@ -383,10 +383,10 @@ void format_concat_output(framework::Tensor *out, int height, int width,
out->reset_data_ptr(data_ptr); out->reset_data_ptr(data_ptr);
} }
void fill_conv_arg(struct WrapperConvArgs *arg, framework::Tensor *input, void fill_split_arg(struct SplitConvArgs *arg, framework::Tensor *input,
framework::Tensor *out, framework::Tensor *filter, framework::Tensor *out, framework::Tensor *filter,
bool relu_enabled, int group_num, int stride_h, int stride_w, bool relu_enabled, int group_num, int stride_h,
int padding_h, int padding_w, float *bs_ptr) { int stride_w, int padding_h, int padding_w, float *bs_ptr) {
auto input_ptr = input->data<float>(); auto input_ptr = input->data<float>();
auto filter_ptr = filter->data<float>(); auto filter_ptr = filter->data<float>();
auto out_ptr = out->data<float>(); auto out_ptr = out->data<float>();
......
...@@ -89,7 +89,7 @@ struct ConcatArgs { ...@@ -89,7 +89,7 @@ struct ConcatArgs {
uint32_t width; uint32_t width;
}; };
struct WrapperConvArgs { struct SplitConvArgs {
uint32_t split_num; uint32_t split_num;
uint32_t group_num; uint32_t group_num;
uint32_t filter_num; uint32_t filter_num;
...@@ -98,6 +98,14 @@ struct WrapperConvArgs { ...@@ -98,6 +98,14 @@ struct WrapperConvArgs {
struct ConcatArgs concat_arg; struct ConcatArgs concat_arg;
}; };
struct GroupConvArgs {
uint32_t group_num;
uint32_t filter_num;
struct ImageOutputArgs output;
struct SplitConvArgs* conv_args;
struct ConcatArgs concat_arg;
};
struct PoolingArgs { struct PoolingArgs {
int16_t mode; // mode: 0:max, 1:avg int16_t mode; // mode: 0:max, 1:avg
half kernel_reciprocal; half kernel_reciprocal;
...@@ -159,30 +167,6 @@ struct MemoryCacheArgs { ...@@ -159,30 +167,6 @@ struct MemoryCacheArgs {
#define IOCTL_FPGA_REG_READ _IOW(IOCTL_FPGA_MAGIC, 28, struct FpgaRegReadArgs) #define IOCTL_FPGA_REG_READ _IOW(IOCTL_FPGA_MAGIC, 28, struct FpgaRegReadArgs)
#define IOCTL_FPGA_REG_WRITE _IOW(IOCTL_FPGA_MAGIC, 29, struct FpgaRegWriteArgs) #define IOCTL_FPGA_REG_WRITE _IOW(IOCTL_FPGA_MAGIC, 29, struct FpgaRegWriteArgs)
enum FPGA_ERR_TYPE {
ERR_IOCTL_CMD = -1,
ERR_TIMEOUT = -2,
ERR_COMPLETION_TIMEOUT = -3,
ERR_INVALID_FPGA_ADDR = -4,
ERR_NOMEM = -5,
ERR_NO_RESERVE_MEM = -6,
ERR_COPY_FROM_USER = -7,
ERR_COPY_TO_USER = -8,
ERR_DEL_TIMER = -9,
ERR_ENABLE_MSI = -10,
ERR_REGISTER_IRQ = -11,
ERR_PCIE_REGISTER = -12,
ERR_PCIE_PROBE = -13,
ERR_REGISTER_BLOCK = -14,
ERR_ALLOC_GENDISK = -15,
ERR_INIT_QUEUE = -16,
ERR_WAIT = -17,
ERR_ECC_ERROR = -31,
ERR_FPGA_FAIL_STOP = -64,
ERR_FPGA_DEBUG_STOP = -113,
DEV_TMP_UNAVAILABLE = -128
};
//============================== API ============================= //============================== API =============================
int open_device(); int open_device();
...@@ -195,7 +179,7 @@ int fpga_flush(void* address, size_t size); ...@@ -195,7 +179,7 @@ int fpga_flush(void* address, size_t size);
int fpga_invalidate(void* address, size_t size); int fpga_invalidate(void* address, size_t size);
int PerformBypass(const struct BypassArgs& args); int PerformBypass(const struct BypassArgs& args);
int ComputeFpgaConv(const struct WrapperConvArgs& args); int ComputeFpgaConv(const struct SplitConvArgs& args);
int ComputeFpgaPool(const struct PoolingArgs& args); int ComputeFpgaPool(const struct PoolingArgs& args);
int ComputeFpgaEWAdd(const struct EWAddArgs& args); int ComputeFpgaEWAdd(const struct EWAddArgs& args);
int ComputeFPGAConcat(const struct ConcatArgs& args); int ComputeFPGAConcat(const struct ConcatArgs& args);
...@@ -220,10 +204,10 @@ void format_bias_scale_array(float** bias_scale_array, ...@@ -220,10 +204,10 @@ void format_bias_scale_array(float** bias_scale_array,
void format_concat_output(framework::Tensor* out, int height, int width, void format_concat_output(framework::Tensor* out, int height, int width,
int image_num, uint32_t* channel_num); int image_num, uint32_t* channel_num);
void fill_conv_arg(struct WrapperConvArgs* arg, framework::Tensor* input, void fill_split_arg(struct SplitConvArgs* arg, framework::Tensor* input,
framework::Tensor* out, framework::Tensor* filter, framework::Tensor* out, framework::Tensor* filter,
bool relu_enabled, int group_num, int stride_h, int stride_w, bool relu_enabled, int group_num, int stride_h,
int padding_h, int padding_w, float* bs_ptr); int stride_w, int padding_h, int padding_w, float* bs_ptr);
half fp32_2_fp16(float fp32_num); half fp32_2_fp16(float fp32_num);
float fp16_2_fp32(half fp16_num); float fp16_2_fp32(half fp16_num);
......
...@@ -21,7 +21,10 @@ namespace paddle_mobile { ...@@ -21,7 +21,10 @@ namespace paddle_mobile {
namespace fpga { namespace fpga {
namespace filter { namespace filter {
int calc_division_capacity(int chw) { return 2048 / ((chw + 15) / 16) * 32; } int calc_division_capacity(int chw) {
int n = 2048 / ((chw + 15) / 16) * 32;
return n < 2048 ? n : 2048;
}
int calc_split_num(int num, int division_capacity) { int calc_split_num(int num, int division_capacity) {
return (num + division_capacity - 1) / division_capacity; return (num + division_capacity - 1) / division_capacity;
......
...@@ -199,6 +199,9 @@ LOAD_OP3(pool2d, CPU, MALI_GPU, FPGA); ...@@ -199,6 +199,9 @@ LOAD_OP3(pool2d, CPU, MALI_GPU, FPGA);
#ifdef MULTICLASSNMS_OP #ifdef MULTICLASSNMS_OP
LOAD_OP1(multiclass_nms, CPU); LOAD_OP1(multiclass_nms, CPU);
#endif #endif
#ifdef POLYGONBOXTRANSFORM_OP
LOAD_OP1(polygon_box_transform, CPU);
#endif
#ifdef SUM_OP #ifdef SUM_OP
LOAD_OP1(sum, CPU); LOAD_OP1(sum, CPU);
#endif #endif
......
...@@ -12,57 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,57 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef MUL_OP #ifdef POLYGONBOXTRANSFORM_OP
#include "operators/kernel/mul_kernel.h" #include "operators/kernel/polygon_box_transform_kernel.h"
#include "operators/kernel/central-arm-func/polygon_box_transform_arm_func.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool MulKernel<FPGA, float>::Init(MulParam<FPGA> *param) { bool PolygonBoxTransformKernel<CPU, float>::Init(
bool relu_enabled = false; PolygonBoxTransformParam<CPU> *param) {
auto input_x = const_cast<LoDTensor *>(param->InputX());
auto filter = const_cast<LoDTensor *>(param->InputY());
auto out = param->Out();
PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == filter->dims()[0],
"Image channel should be equal to weight number");
int channel = (uint32_t)out->dims()[1];
auto bs_ptr =
(float *)fpga::fpga_malloc(2 * channel * sizeof(float)); // NOLINT
for (int i = 0; i < channel; i++) {
bs_ptr[i + channel] = 1;
bs_ptr[i] = 0;
}
int num = (uint32_t)filter->dims()[1];
int chw = (uint32_t)filter->dims()[0];
PADDLE_MOBILE_ENFORCE(
chw == input_x->numel(),
"Filter element num should be equal to IFM element num");
int height = (uint32_t)input_x->dims()[2];
int width = (uint32_t)input_x->dims()[3];
int filter_channel = chw / height / width;
out->Resize(framework::make_ddim({1, channel, 1, 1}));
filter->Resize(framework::make_ddim({num, filter_channel, height, width}));
float max_value = fpga::filter_find_max(filter);
fpga::format_fc_filter(filter, max_value);
int element_num_per_div = fpga::get_filter_num_per_div(filter, 1);
fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel);
fpga::format_fp16_ofm(out);
fpga::WrapperConvArgs conv_arg = {0};
fpga::fill_conv_arg(&conv_arg, input_x, out, filter, relu_enabled, 1, 1, 1, 0,
0, bs_ptr);
param->SetFpgaArgs(conv_arg);
return true; return true;
} }
template <> template <>
void MulKernel<FPGA, float>::Compute(const MulParam<FPGA> &param) const { void PolygonBoxTransformKernel<CPU, float>::Compute(
fpga::ComputeFpgaConv(param.FpgaArgs()); const PolygonBoxTransformParam<CPU> &param) const {
PolygonBoxTransformCompute<float>(param);
} }
} // namespace operators } // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POLYGONBOXTRANSFORM_OP
#pragma once
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void PolygonBoxTransformCompute(const PolygonBoxTransformParam<CPU>& param) {
const auto* input = param.Input();
const auto& input_dims = input->dims();
const auto* input_data = input->data<float>();
auto* output = param.Output();
auto* output_data = output->mutable_data<float>();
int64_t batch_size = input_dims[0];
int64_t geo_channel = input_dims[1];
int64_t height = input_dims[2];
int64_t width = input_dims[3];
int64_t id = 0;
for (int64_t id_n = 0; id_n < batch_size * geo_channel; ++id_n) {
for (int64_t id_h = 0; id_h < height; ++id_h) {
for (int64_t id_w = 0; id_w < width; ++id_w) {
id = id_n * height * width + width * id_h + id_w;
if (id_n % 2 == 0) {
output_data[id] = id_w * 4 - input_data[id];
} else {
output_data[id] = id_h * 4 - input_data[id];
}
}
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -66,10 +66,11 @@ bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam<FPGA> *param) { ...@@ -66,10 +66,11 @@ bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam<FPGA> *param) {
fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel); fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel);
fpga::format_fp16_ofm(out); fpga::format_fp16_ofm(out);
fpga::WrapperConvArgs conv_arg = {0}; fpga::SplitConvArgs conv_arg = {0};
fpga::fill_conv_arg(&conv_arg, input, out, filter, relu_enabled, fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0], param->Strides()[1], param->Groups(), param->Strides()[0],
param->Paddings()[0], param->Paddings()[1], bs_ptr); param->Strides()[1], param->Paddings()[0],
param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg); param->SetFpgaArgs(conv_arg);
return true; return true;
......
...@@ -65,10 +65,11 @@ bool ConvAddBNReluKernel<FPGA, float>::Init( ...@@ -65,10 +65,11 @@ bool ConvAddBNReluKernel<FPGA, float>::Init(
fpga::format_fp16_ofm(out); fpga::format_fp16_ofm(out);
fpga::WrapperConvArgs conv_arg = {0}; fpga::SplitConvArgs conv_arg = {0};
fpga::fill_conv_arg(&conv_arg, input, out, filter, relu_enabled, fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0], param->Strides()[1], param->Groups(), param->Strides()[0],
param->Paddings()[0], param->Paddings()[1], bs_ptr); param->Strides()[1], param->Paddings()[0],
param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg); param->SetFpgaArgs(conv_arg);
return true; return true;
} }
......
...@@ -47,10 +47,11 @@ bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam<FPGA> *param) { ...@@ -47,10 +47,11 @@ bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam<FPGA> *param) {
fpga::format_fp16_ofm(out); fpga::format_fp16_ofm(out);
fpga::WrapperConvArgs conv_arg = {0}; fpga::SplitConvArgs conv_arg = {0};
fpga::fill_conv_arg(&conv_arg, input, out, filter, relu_enabled, fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0], param->Strides()[1], param->Groups(), param->Strides()[0],
param->Paddings()[0], param->Paddings()[1], bs_ptr); param->Strides()[1], param->Paddings()[0],
param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg); param->SetFpgaArgs(conv_arg);
return true; return true;
} }
......
...@@ -59,10 +59,11 @@ bool ConvBNKernel<FPGA, float>::Init(FusionConvBNParam<FPGA> *param) { ...@@ -59,10 +59,11 @@ bool ConvBNKernel<FPGA, float>::Init(FusionConvBNParam<FPGA> *param) {
fpga::format_fp16_ofm(out); fpga::format_fp16_ofm(out);
fpga::WrapperConvArgs conv_arg = {0}; fpga::SplitConvArgs conv_arg = {0};
fpga::fill_conv_arg(&conv_arg, input, out, filter, relu_enabled, fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0], param->Strides()[1], param->Groups(), param->Strides()[0],
param->Paddings()[0], param->Paddings()[1], bs_ptr); param->Strides()[1], param->Paddings()[0],
param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg); param->SetFpgaArgs(conv_arg);
return true; return true;
} }
......
...@@ -59,10 +59,11 @@ bool ConvBNReluKernel<FPGA, float>::Init(FusionConvBNReluParam<FPGA> *param) { ...@@ -59,10 +59,11 @@ bool ConvBNReluKernel<FPGA, float>::Init(FusionConvBNReluParam<FPGA> *param) {
fpga::format_fp16_ofm(out); fpga::format_fp16_ofm(out);
fpga::WrapperConvArgs conv_arg = {0}; fpga::SplitConvArgs conv_arg = {0};
fpga::fill_conv_arg(&conv_arg, input, out, filter, relu_enabled, fpga::fill_split_arg(&conv_arg, input, out, filter, relu_enabled,
param->Groups(), param->Strides()[0], param->Strides()[1], param->Groups(), param->Strides()[0],
param->Paddings()[0], param->Paddings()[1], bs_ptr); param->Strides()[1], param->Paddings()[0],
param->Paddings()[1], bs_ptr);
param->SetFpgaArgs(conv_arg); param->SetFpgaArgs(conv_arg);
return true; return true;
} }
......
...@@ -53,9 +53,9 @@ bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam<FPGA> *param) { ...@@ -53,9 +53,9 @@ bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam<FPGA> *param) {
fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel); fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel);
fpga::format_fp16_ofm(out); fpga::format_fp16_ofm(out);
fpga::WrapperConvArgs conv_arg = {0}; fpga::SplitConvArgs conv_arg = {0};
fpga::fill_conv_arg(&conv_arg, input_x, out, filter, relu_enabled, 1, 1, 1, 0, fpga::fill_split_arg(&conv_arg, input_x, out, filter, relu_enabled, 1, 1, 1,
0, bs_ptr); 0, 0, bs_ptr);
param->SetFpgaArgs(conv_arg); param->SetFpgaArgs(conv_arg);
return true; return true;
} }
......
...@@ -54,9 +54,9 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam<FPGA> *param) { ...@@ -54,9 +54,9 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam<FPGA> *param) {
fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel); fpga::format_bias_scale_array(&bs_ptr, element_num_per_div, channel);
fpga::format_fp16_ofm(out); fpga::format_fp16_ofm(out);
fpga::WrapperConvArgs conv_arg = {0}; fpga::SplitConvArgs conv_arg = {0};
fpga::fill_conv_arg(&conv_arg, input_x, out, filter, relu_enabled, 1, 1, 1, 0, fpga::fill_split_arg(&conv_arg, input_x, out, filter, relu_enabled, 1, 1, 1,
0, bs_ptr); 0, 0, bs_ptr);
param->SetFpgaArgs(conv_arg); param->SetFpgaArgs(conv_arg);
return true; return true;
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POLYGONBOXTRANSFORM_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class PolygonBoxTransformKernel
: public framework::OpKernelBase<DeviceType,
PolygonBoxTransformParam<DeviceType>> {
public:
void Compute(const PolygonBoxTransformParam<DeviceType>& param) const;
bool Init(PolygonBoxTransformParam<DeviceType>* param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -546,11 +546,11 @@ class MulParam : OpParam { ...@@ -546,11 +546,11 @@ class MulParam : OpParam {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
fpga::WrapperConvArgs fpga_conv_args; fpga::SplitConvArgs fpga_conv_args;
public: public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; } const fpga::SplitConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; } void SetFpgaArgs(const fpga::SplitConvArgs &args) { fpga_conv_args = args; }
#endif #endif
}; };
#endif #endif
...@@ -999,6 +999,28 @@ class MultiClassNMSParam : public OpParam { ...@@ -999,6 +999,28 @@ class MultiClassNMSParam : public OpParam {
}; };
#endif #endif
#ifdef POLYGONBOXTRANSFORM_OP
template <typename Dtype>
class PolygonBoxTransformParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
PolygonBoxTransformParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = InputFrom<GType>(inputs, scope);
output_ = OutputFrom<GType>(outputs, scope);
}
const RType *Input() const { return input_; }
RType *Output() const { return output_; }
private:
RType *input_;
RType *output_;
};
#endif
template <typename Dtype> template <typename Dtype>
class FeedParam : public OpParam { class FeedParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
...@@ -1401,11 +1423,11 @@ class FusionFcParam : public OpParam { ...@@ -1401,11 +1423,11 @@ class FusionFcParam : public OpParam {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
fpga::WrapperConvArgs fpga_conv_args; fpga::SplitConvArgs fpga_conv_args;
public: public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; } const fpga::SplitConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; } void SetFpgaArgs(const fpga::SplitConvArgs &args) { fpga_conv_args = args; }
#endif #endif
}; };
...@@ -1441,11 +1463,11 @@ class FusionConvAddParam : public ConvParam<Dtype> { ...@@ -1441,11 +1463,11 @@ class FusionConvAddParam : public ConvParam<Dtype> {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
fpga::WrapperConvArgs fpga_conv_args; fpga::SplitConvArgs fpga_conv_args;
public: public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; } const fpga::SplitConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; } void SetFpgaArgs(const fpga::SplitConvArgs &args) { fpga_conv_args = args; }
#endif #endif
}; };
...@@ -1496,11 +1518,11 @@ class FusionConvAddPReluParam : public ConvParam<Dtype> { ...@@ -1496,11 +1518,11 @@ class FusionConvAddPReluParam : public ConvParam<Dtype> {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
fpga::WrapperConvArgs fpga_conv_args; fpga::SplitConvArgs fpga_conv_args;
public: public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; } const fpga::SplitConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; } void SetFpgaArgs(const fpga::SplitConvArgs &args) { fpga_conv_args = args; }
#endif #endif
}; };
#endif #endif
...@@ -1554,11 +1576,11 @@ class FusionConvAddAddPReluParam : public ConvParam<Dtype> { ...@@ -1554,11 +1576,11 @@ class FusionConvAddAddPReluParam : public ConvParam<Dtype> {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
fpga::WrapperConvArgs fpga_conv_args; fpga::SplitConvArgs fpga_conv_args;
public: public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; } const fpga::SplitConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; } void SetFpgaArgs(const fpga::SplitConvArgs &args) { fpga_conv_args = args; }
#endif #endif
}; };
#endif #endif
...@@ -1629,11 +1651,11 @@ class FusionConvAddBNReluParam : public ConvParam<Dtype> { ...@@ -1629,11 +1651,11 @@ class FusionConvAddBNReluParam : public ConvParam<Dtype> {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
fpga::WrapperConvArgs fpga_conv_args; fpga::SplitConvArgs fpga_conv_args;
public: public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; } const fpga::SplitConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; } void SetFpgaArgs(const fpga::SplitConvArgs &args) { fpga_conv_args = args; }
#endif #endif
}; };
#endif #endif
...@@ -1715,11 +1737,11 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> { ...@@ -1715,11 +1737,11 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
fpga::WrapperConvArgs fpga_conv_args; fpga::SplitConvArgs fpga_conv_args;
public: public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; } const fpga::SplitConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; } void SetFpgaArgs(const fpga::SplitConvArgs &args) { fpga_conv_args = args; }
#endif #endif
}; };
#endif #endif
...@@ -1782,11 +1804,11 @@ class FusionConvBNParam : public ConvParam<Dtype> { ...@@ -1782,11 +1804,11 @@ class FusionConvBNParam : public ConvParam<Dtype> {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
fpga::WrapperConvArgs fpga_conv_args; fpga::SplitConvArgs fpga_conv_args;
public: public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; } const fpga::SplitConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; } void SetFpgaArgs(const fpga::SplitConvArgs &args) { fpga_conv_args = args; }
#endif #endif
}; };
#endif #endif
...@@ -1857,11 +1879,11 @@ class FusionConvAddBNParam : public ConvParam<Dtype> { ...@@ -1857,11 +1879,11 @@ class FusionConvAddBNParam : public ConvParam<Dtype> {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
fpga::WrapperConvArgs fpga_conv_args; fpga::SplitConvArgs fpga_conv_args;
public: public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; } const fpga::SplitConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; } void SetFpgaArgs(const fpga::SplitConvArgs &args) { fpga_conv_args = args; }
#endif #endif
}; };
#endif #endif
...@@ -1983,11 +2005,11 @@ class FusionConvBNReluParam : public ConvParam<Dtype> { ...@@ -1983,11 +2005,11 @@ class FusionConvBNReluParam : public ConvParam<Dtype> {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
private: private:
fpga::WrapperConvArgs fpga_conv_args; fpga::SplitConvArgs fpga_conv_args;
public: public:
const fpga::WrapperConvArgs &FpgaArgs() const { return fpga_conv_args; } const fpga::SplitConvArgs &FpgaArgs() const { return fpga_conv_args; }
void SetFpgaArgs(const fpga::WrapperConvArgs &args) { fpga_conv_args = args; } void SetFpgaArgs(const fpga::SplitConvArgs &args) { fpga_conv_args = args; }
#endif #endif
}; };
#endif #endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POLYGONBOXTRANSFORM_OP
#include "operators/polygon_box_transform_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void PolygonBoxTransformOp<Dtype, T>::InferShape() const {
PADDLE_MOBILE_ENFORCE(this->param_.Input() != nullptr,
"Input (Input) of get_shape op should not be null.");
PADDLE_MOBILE_ENFORCE(this->param_.Output() != nullptr,
"Output (Output) of get_shape op should not be null.");
auto input_dims = this->param_.Input()->dims();
PADDLE_MOBILE_ENFORCE(input_dims.size() == 4, "input's rank must be 4.");
PADDLE_MOBILE_ENFORCE(input_dims[1] % 2 == 0,
"input's second dimension must be even.");
this->param_.Output()->Resize(input_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(polygon_box_transform, ops::PolygonBoxTransformOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POLYGONBOXTRANSFORM_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/polygon_box_transform_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class PolygonBoxTransformOp
: public framework::OperatorWithKernel<
DeviceType, PolygonBoxTransformParam<DeviceType>,
operators::PolygonBoxTransformKernel<DeviceType, T>> {
public:
PolygonBoxTransformOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, PolygonBoxTransformParam<DeviceType>,
operators::PolygonBoxTransformKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, PolygonBoxTransformParam<DeviceType>,
operators::PolygonBoxTransformKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -65,7 +65,6 @@ REGISTER_OPERATOR_CPU(sum, ops::SumOp); ...@@ -65,7 +65,6 @@ REGISTER_OPERATOR_CPU(sum, ops::SumOp);
REGISTER_OPERATOR_MALI_GPU(sum, ops::ConcatOp); REGISTER_OPERATOR_MALI_GPU(sum, ops::ConcatOp);
#endif #endif
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(sum, ops::ConcatOp);
#endif #endif
#endif #endif
...@@ -61,38 +61,11 @@ endif () ...@@ -61,38 +61,11 @@ endif ()
list(FIND NET "FPGAnets" CON) list(FIND NET "FPGAnets" CON)
if (CON GREATER -1) if (CON GREATER -1)
ADD_EXECUTABLE(test-resnet net/test_resnet.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-resnet paddle-mobile)
ADD_EXECUTABLE(test-resnet50 fpga/test_resnet50.cpp test_helper.h test_include.h executor_for_test.h) ADD_EXECUTABLE(test-resnet50 fpga/test_resnet50.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-resnet50 paddle-mobile) target_link_libraries(test-resnet50 paddle-mobile)
ADD_EXECUTABLE(test-fpga-EW fpga/test_fpga_EW.cpp test_helper.h test_include.h executor_for_test.h) # ADD_EXECUTABLE(test-resnet net/test_resnet.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-fpga-EW paddle-mobile) # target_link_libraries(test-resnet paddle-mobile)
ADD_EXECUTABLE(test-fpga-conv fpga/test_fpga_conv.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-fpga-conv paddle-mobile)
ADD_EXECUTABLE(test-fpga-pooling fpga/test_fpga_pooling.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-fpga-pooling paddle-mobile)
ADD_EXECUTABLE(test-fpga-bypass fpga/test_fpga_bypass.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-fpga-bypass paddle-mobile)
ADD_EXECUTABLE(test-fpga-softmax fpga/test_fpga_softmax.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-fpga-softmax paddle-mobile)
ADD_EXECUTABLE(test-fpga-concat fpga/test_fpga_concat.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-fpga-concat paddle-mobile)
ADD_EXECUTABLE(test-tensor-quant fpga/test_tensor_quant.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-tensor-quant paddle-mobile)
ADD_EXECUTABLE(test-fpga-concat-op fpga/test_concat_op.cpp test_helper.h test_include.h)
target_link_libraries(test-fpga-concat-op paddle-mobile)
ADD_EXECUTABLE(test-format-data fpga/test_format_data.cpp test_helper.h test_include.h)
target_link_libraries(test-format-data paddle-mobile)
set(FOUND_MATCH ON) set(FOUND_MATCH ON)
endif () endif ()
...@@ -208,6 +181,10 @@ if (NOT FOUND_MATCH) ...@@ -208,6 +181,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-multiclassnms-op operators/test_multiclass_nms_op.cpp test_helper.h test_include.h) ADD_EXECUTABLE(test-multiclassnms-op operators/test_multiclass_nms_op.cpp test_helper.h test_include.h)
target_link_libraries(test-multiclassnms-op paddle-mobile) target_link_libraries(test-multiclassnms-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-polygon-box-transform-op operators/test_polygon_box_transform_op.cpp test_helper.h test_include.h)
target_link_libraries(test-polygon-box-transform-op paddle-mobile)
# gen test # gen test
ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h) ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h)
target_link_libraries(test-reshape-op paddle-mobile) target_link_libraries(test-reshape-op paddle-mobile)
......
...@@ -30,7 +30,11 @@ int main() { ...@@ -30,7 +30,11 @@ int main() {
input_tensor.data<float>() + input_tensor.numel()); input_tensor.data<float>() + input_tensor.numel());
paddle_mobile.FeedData(input_tensor); paddle_mobile.FeedData(input_tensor);
paddle_mobile.Predict_To(-1); for (int i = 0; i < 1000; i++) {
paddle_mobile.Predict_To(-1);
if (i % 100 == 0) std::cout << i << std::endl;
}
// paddle_mobile.Predict_From(73); // paddle_mobile.Predict_From(73);
// paddle_mobile.Predict_From_To(72, 73); // paddle_mobile.Predict_From_To(72, 73);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_include.h"
#include "operators/polygon_box_transform_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype>
class TestPolygonBoxTransformOp {
public:
explicit TestPolygonBoxTransformOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
for (auto block_desc : blocks) {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
for (auto op : ops) {
if (op->Type() == "polygon_box_transform") {
DLOG << " attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " input is : " << op->Input("Input")[0];
input_var_name = op->Input("Input")[0];
DLOG << " outputs size: " << op->GetOutputs().size();
DLOG << " output is : " << op->Output("Output")[0];
output_var_name = op->Output("Output")[0];
std::shared_ptr<operators::PolygonBoxTransformOp<Dtype, float>>
op_ptr = std::make_shared<
operators::PolygonBoxTransformOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(op_ptr);
return;
}
}
}
}
std::shared_ptr<Tensor> predict(const Tensor &t) {
auto scope = program_.scope;
Variable *input_feed_value = scope->Var(input_var_name);
auto tensor_input = input_feed_value->GetMutable<LoDTensor>();
tensor_input->ShareDataWith(t);
Variable *output = scope->Var(output_var_name);
auto *output_tensor = output->GetMutable<LoDTensor>();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict(t, 0);
return out_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
string input_var_name;
string output_var_name;
void predict(const Tensor &t, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
op->Run();
}
}
};
template class TestPolygonBoxTransformOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run PolygonBoxTransform Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_ocr));
paddle_mobile::framework::Tensor input;
SetupTensor<float>(&input, {1, 8, 1, 2}, static_cast<float>(0),
static_cast<float>(1));
auto *input_ptr = input.data<float>();
for (int i = 0; i < 16; ++i) {
*(input_ptr + i) = i;
}
DLOG << "input : ";
for (int i = 0; i < input.numel(); ++i) {
DLOG << " index " << i << " : " << input_ptr[i];
}
paddle_mobile::framework::TestPolygonBoxTransformOp<paddle_mobile::CPU>
testPolygonBoxTransformOp(program);
auto output = testPolygonBoxTransformOp.predict(input);
auto *output_ptr = output->data<float>();
DLOG << "output : ";
for (int i = 0; i < output->numel(); ++i) {
DLOG << " index " << i << " : " << output_ptr[i];
}
return 0;
}
...@@ -118,12 +118,9 @@ if (CON GREATER -1) ...@@ -118,12 +118,9 @@ if (CON GREATER -1)
set(POOL_OP ON) set(POOL_OP ON)
set(CONCAT_OP ON) set(CONCAT_OP ON)
set(SOFTMAX_OP ON) set(SOFTMAX_OP ON)
set(DROPOUT_OP ON)
set(FUSION_CONVBNRELU_OP ON) set(FUSION_CONVBNRELU_OP ON)
set(FUSION_CONVBN_OP ON) set(FUSION_CONVBN_OP ON)
set(FUSION_CONVADD_OP ON) set(FUSION_CONVADD_OP ON)
set(MUL_OP ON)
set(FOUND_MATCH ON) set(FOUND_MATCH ON)
endif() endif()
...@@ -198,6 +195,7 @@ if(NOT FOUND_MATCH) ...@@ -198,6 +195,7 @@ if(NOT FOUND_MATCH)
set(LRN_OP ON) set(LRN_OP ON)
set(MUL_OP ON) set(MUL_OP ON)
set(MULTICLASSNMS_OP ON) set(MULTICLASSNMS_OP ON)
set(POLYGONBOXTRANSFORM_OP ON)
set(POOL_OP ON) set(POOL_OP ON)
set(PRIORBOX_OP ON) set(PRIORBOX_OP ON)
set(RELU_OP ON) set(RELU_OP ON)
...@@ -239,6 +237,7 @@ endif() ...@@ -239,6 +237,7 @@ endif()
# option(LRN_OP "" ON) # option(LRN_OP "" ON)
# option(MUL_OP "" ON) # option(MUL_OP "" ON)
# option(MULTICLASSNMS_OP "" ON) # option(MULTICLASSNMS_OP "" ON)
# option(POLYGONBOXTRANSFORM_OP "" ON)
# option(POOL_OP "" ON) # option(POOL_OP "" ON)
# option(PRIORBOX_OP "" ON) # option(PRIORBOX_OP "" ON)
# option(RELU_OP "" ON) # option(RELU_OP "" ON)
...@@ -293,6 +292,9 @@ endif() ...@@ -293,6 +292,9 @@ endif()
if (MULTICLASSNMS_OP) if (MULTICLASSNMS_OP)
add_definitions(-DMULTICLASSNMS_OP) add_definitions(-DMULTICLASSNMS_OP)
endif() endif()
if (POLYGONBOXTRANSFORM_OP)
add_definitions(-DPOLYGONBOXTRANSFORM_OP)
endif()
if (POOL_OP) if (POOL_OP)
add_definitions(-DPOOL_OP) add_definitions(-DPOOL_OP)
endif() endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册