提交 1dc41116 编写于 作者: Z Zhen Wang

Merge branch 'add-int8-gemm' of https://github.com/wzzju/paddle-mobile into add-int8-gemm

......@@ -22,7 +22,7 @@ limitations under the License. */
#include "fpga/filter.h"
#include "fpga/image.h"
#define FPGA_TEST_MODE
#define PADDLE_MOBILE_OS_LINUX
//#define PADDLE_MOBILE_OS_LINUX
namespace paddle_mobile {
namespace fpga {
......@@ -125,6 +125,7 @@ float fp16_2_fp32(half fp16_num) {
}
int ComputeBasicConv(const struct ConvArgs &args) {
#ifdef FPGA_TEST_MODE
DLOG << "======Compute Basic Conv======";
DLOG << " relu_enabled:" << args.relu_enabled
<< " sb_address:" << args.sb_address
......@@ -144,7 +145,7 @@ int ComputeBasicConv(const struct ConvArgs &args) {
<< " stride_w:" << args.kernel.stride_w;
DLOG << " out_address:" << args.output.address
<< " out_scale_address:" << args.output.scale_address;
#endif
return do_ioctl(IOCTL_CONFIG_CONV, &args);
}
......@@ -192,8 +193,9 @@ int ComputeFpgaPool(const struct PoolingArgs &args) {
int ComputeFpgaEWAdd(const struct EWAddArgs &args) {
#ifdef FPGA_TEST_MODE
DLOG << "=============ComputeFpgaEWAdd===========";
DLOG << " relu_enabled:" << args.relu_enabled << " const0:" << args.const0
<< " const1:" << args.const1;
DLOG << " relu_enabled:" << args.relu_enabled
<< " const0:" << fp16_2_fp32(short(args.const0))
<< " const1:" << fp16_2_fp32(short(args.const1));
DLOG << " image0_address:" << args.image0.address
<< " image0_scale_address:" << args.image0.scale_address
<< " image0_channels:" << args.image0.channels
......@@ -401,8 +403,8 @@ void fill_conv_arg(struct WrapperConvArgs *arg, framework::Tensor *input,
arg->concat_arg.image_num = arg->split_num;
arg->concat_arg.image_out = out_ptr;
arg->concat_arg.scale_out = out->scale;
arg->concat_arg.height = (uint32_t)filter->dims()[2];
arg->concat_arg.width = (uint32_t)filter->dims()[3];
arg->concat_arg.height = (uint32_t)out->dims()[2];
arg->concat_arg.width = (uint32_t)out->dims()[3];
int n = arg->split_num;
arg->concat_arg.images_in =
......@@ -411,7 +413,6 @@ void fill_conv_arg(struct WrapperConvArgs *arg, framework::Tensor *input,
(float **)fpga_malloc(n * sizeof(float *)); // NOLINT
arg->concat_arg.channel_num =
(uint32_t *)fpga_malloc(n * sizeof(uint32_t)); // NOLINT
arg->concat_arg.image_out = out_ptr;
auto channel = (int)out->dims()[1]; // NOLINT
int filter_num_per_div = get_filter_num_per_div(filter, group_num);
......
......@@ -27,6 +27,9 @@ void align_element(float **data_in, int num_per_div_before_alignment, int num) {
(num + num_per_div_before_alignment - 1) / num_per_div_before_alignment;
int num_per_div_after_alignment =
align_to_x(num_per_div_before_alignment, BS_NUM_ALIGNMENT);
if (num_per_div_before_alignment == num_per_div_after_alignment) {
return;
}
int num_element =
2 * div_num * num_per_div_after_alignment; // including bias & scale
float *ptr_aligned =
......
......@@ -210,12 +210,12 @@ void format_filter(float **data_in, int num, int channel, int height, int width,
align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT);
int div_num =
(num + num_per_div_before_alignment - 1) / num_per_div_before_alignment;
int num_after_alignment = num_per_div_after_alignment * div_num;
int residual = num % num_per_div_before_alignment;
int num_after_alignment = num_per_div_after_alignment *
((residual == 0) ? div_num : (div_num - 1)) +
align_to_x(residual, FILTER_NUM_ALIGNMENT);
quantize(data_in, data_size, max);
char **quantize_data = (char **)data_in; // NOLINT
convert_to_hwc(quantize_data, num, channel, height, width);
align_element(quantize_data, num, chw);
align_num(quantize_data, num_per_div_before_alignment, num, chw);
......
......@@ -199,6 +199,12 @@ LOAD_OP3(pool2d, CPU, MALI_GPU, FPGA);
#ifdef MULTICLASSNMS_OP
LOAD_OP1(multiclass_nms, CPU);
#endif
#ifdef SUM_OP
LOAD_OP1(sum, CPU);
#endif
#ifdef ELEMENTWISEMUL_OP
LOAD_OP1(elementwise_mul, CPU);
#endif
#ifdef SLICE_OP
LOAD_OP2(slice, CPU, MALI_GPU);
#endif
......@@ -206,5 +212,8 @@ LOAD_OP2(slice, CPU, MALI_GPU);
LOAD_OP2(fusion_conv_bn, CPU, FPGA);
LOAD_FUSION_MATCHER(fusion_conv_bn);
#endif
#ifdef ELEMENTWISESUB_OP
LOAD_OP1(elementwise_sub, CPU)
#endif
LOAD_OP1(quantize, CPU);
LOAD_OP1(dequantize, CPU);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISESUB_OP
#include "operators/elementwise_sub_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void ElementwiseSubOp<Dtype, T>::InferShape() const {
auto x_dim = this->param_.InputX()->dims();
this->param_.Out()->Resize(x_dim);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(elementwise_sub, ops::ElementwiseSubOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
REGISTER_OPERATOR_MALI_GPU(elementwise_sub, ops::ElementwiseSubOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISESUB_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "kernel/elementwise_sub_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class ElementwiseSubOp : public framework::OperatorWithKernel<
DeviceType, ElementwiseSubParam<DeviceType>,
operators::ElementwiseSubKernel<DeviceType, T>> {
public:
ElementwiseSubOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, ElementwiseSubParam<DeviceType>,
operators::ElementwiseSubKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, ElementwiseSubParam<DeviceType>,
operators::ElementwiseSubKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISESUB_OP
#include "operators/kernel/elementwise_sub_kernel.h"
#include "operators/kernel/central-arm-func/elementwise_sub_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ElementwiseSubKernel<CPU, float>::Init(ElementwiseSubParam<CPU> *param) {
return true;
}
template <>
void ElementwiseSubKernel<CPU, float>::Compute(
const ElementwiseSubParam<CPU> &param) const {
ElementwiseSubCompute<float>(param);
param.Out()->set_lod(param.InputX()->lod());
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISESUB_OP
#pragma once
#include "operators/math/elementwise_op_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename T>
struct SubFunctor {
inline T operator()(T a, T b) const { return a - b; }
};
template <typename P>
void ElementwiseSubCompute(const ElementwiseSubParam<CPU> &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *Out = param.Out();
Out->mutable_data<float>();
int axis = param.Axis();
ElementwiseComputeEx<SubFunctor<float>, float>(input_x, input_y, axis,
SubFunctor<float>(), Out);
}
template class ElementwiseSubKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -20,14 +20,12 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/poly_util.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
constexpr int kOutputDim = 6;
constexpr int kBBoxSize = 4;
template <class T>
bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
......@@ -90,6 +88,21 @@ static inline T JaccardOverlap(const T* box1, const T* box2,
}
}
template <class T>
static inline T PolyIoU(const T* box1, const T* box2, const size_t box_size,
const bool normalized) {
T bbox1_area = math::PolyArea<T>(box1, box_size, normalized);
T bbox2_area = math::PolyArea<T>(box2, box_size, normalized);
T inter_area = math::PolyOverlapArea<T>(box1, box2, box_size, normalized);
if (bbox1_area == 0 || bbox2_area == 0 || inter_area == 0) {
// If coordinate values are is invalid
// if area size <= 0, return 0.
return static_cast<T>(0.);
} else {
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <typename T>
static inline void NMSFast(const framework::Tensor& bbox,
const framework::Tensor& scores,
......@@ -116,8 +129,14 @@ static inline void NMSFast(const framework::Tensor& bbox,
for (size_t k = 0; k < selected_indices->size(); ++k) {
if (keep) {
const int kept_idx = (*selected_indices)[k];
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
T overlap = T(0.);
if (box_size == 4) {
overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, true);
} else {
overlap = PolyIoU<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, box_size, true);
}
keep = overlap <= adaptive_threshold;
} else {
break;
......@@ -190,6 +209,8 @@ void MultiClassOutput(const framework::Tensor& scores,
const std::map<int, std::vector<int>>& selected_indices,
framework::Tensor* outs) {
int predict_dim = scores.dims()[1];
int box_size = bboxes.dims()[1];
int out_dim = bboxes.dims()[1] + 2;
auto* scores_data = scores.data<T>();
auto* bboxes_data = bboxes.data<T>();
auto* odata = outs->data<T>();
......@@ -202,11 +223,11 @@ void MultiClassOutput(const framework::Tensor& scores,
const std::vector<int>& indices = it.second;
for (size_t j = 0; j < indices.size(); ++j) {
int idx = indices[j];
const T* bdata = bboxes_data + idx * kBBoxSize;
odata[count * kOutputDim] = label; // label
odata[count * kOutputDim + 1] = sdata[idx]; // score
const T* bdata = bboxes_data + idx * box_size;
odata[count * out_dim] = label; // label
odata[count * out_dim + 1] = sdata[idx]; // score
// xmin, ymin, xmax, ymax
std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T));
std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T));
count++;
}
}
......@@ -256,7 +277,8 @@ void MultiClassNMSCompute(const MultiClassNMSParam<CPU>& param) {
float* od = outs->mutable_data<float>({1});
od[0] = -1;
} else {
outs->mutable_data<float>({num_kept, kOutputDim});
int64_t out_dim = box_dim + 2;
outs->mutable_data<float>({num_kept, out_dim});
for (int64_t i = 0; i < batch_size; ++i) {
framework::Tensor ins_score = input_scores->Slice(i, i + 1);
ins_score.Resize({class_num, predict_dim});
......
......@@ -27,13 +27,11 @@ void SumCompute(const SumParam<CPU> &param) {
auto *outvar = param.OutVar();
bool in_place = outvar == inputsvars[0];
DLOG << "11:";
if (outvar->IsType<framework::LoDTensor>()) {
auto *out = outvar->GetMutable<LoDTensor>();
if (!in_place) {
out->mutable_data<float>();
}
DLOG << "1:";
auto *outptr = out->data<float>();
// auto result = Flatten(*out);
......@@ -62,7 +60,6 @@ void SumCompute(const SumParam<CPU> &param) {
}
} else if (outvar->IsType<framework::SelectedRows>()) {
DLOG << "2:";
std::unique_ptr<framework::SelectedRows> in0;
if (in_place) {
// If is in_place, we store the input[0] to in0
......@@ -119,12 +116,12 @@ void SumCompute(const SumParam<CPU> &param) {
if (sel_row.rows().size() == 0) {
continue;
}
PADDLE_MOBILE_ENFORCE(out->height() == sel_row.height());
PADDLE_MOBILE_ENFORCE(out->height() == sel_row.height(),
"seletrows height != outheight");
functor(sel_row, offset, out);
offset += sel_row.value().numel();
}
} else if (outvar->IsType<LoDTensorArray>()) {
DLOG << "3:";
auto &out_array = *outvar->GetMutable<LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < inputsvars.size(); ++i) {
PADDLE_MOBILE_ENFORCE(inputsvars[i]->IsType<LoDTensorArray>(),
......@@ -140,7 +137,8 @@ void SumCompute(const SumParam<CPU> &param) {
framework::TensorCopy((*in_array)[i], &out_array[i]);
out_array[i].set_lod((*in_array)[i].lod());
} else {
PADDLE_MOBILE_ENFORCE(out_array[i].lod() == (*in_array)[i].lod());
PADDLE_MOBILE_ENFORCE(out_array[i].lod() == (*in_array)[i].lod(),
"outLod != inLod");
auto *inptr = (*in_array)[i].data<float>();
auto *outptr = out_array[i].data<float>();
......@@ -152,9 +150,7 @@ void SumCompute(const SumParam<CPU> &param) {
}
}
} else {
DLOG << "2:";
if (outvar->IsType<framework::Tensor>()) {
DLOG << "3: ";
}
PADDLE_MOBILE_THROW_EXCEPTION(
"Unexpected branch, output variable type is %s", outvar->Type().name());
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISEADD_OP
#pragma once
#include "framework/operator.h"
#include "operators/math/elementwise_op_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class ElementwiseSubKernel
: public framework::OpKernelBase<DeviceType,
ElementwiseSubParam<DeviceType>> {
public:
void Compute(const ElementwiseSubParam<DeviceType> &param) const;
bool Init(ElementwiseSubParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -44,6 +44,7 @@ bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam<FPGA> *param) {
int width = (uint32_t)input_x->dims()[3];
int filter_channel = chw / height / width;
out->Resize(framework::make_ddim({1, channel, 1, 1}));
filter->Resize(framework::make_ddim({num, filter_channel, height, width}));
float max_value = fpga::filter_find_max(filter);
fpga::format_fc_filter(filter, max_value);
......
......@@ -45,6 +45,7 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam<FPGA> *param) {
int width = (uint32_t)input_x->dims()[3];
int filter_channel = chw / height / width;
out->Resize(framework::make_ddim({1, channel, 1, 1}));
filter->Resize(framework::make_ddim({num, filter_channel, height, width}));
float max_value = fpga::filter_find_max(filter);
fpga::format_fc_filter(filter, max_value);
......
......@@ -44,6 +44,7 @@ bool MulKernel<FPGA, float>::Init(MulParam<FPGA> *param) {
int width = (uint32_t)input_x->dims()[3];
int filter_channel = chw / height / width;
out->Resize(framework::make_ddim({1, channel, 1, 1}));
filter->Resize(framework::make_ddim({num, filter_channel, height, width}));
float max_value = fpga::filter_find_max(filter);
fpga::format_fc_filter(filter, max_value);
......
......@@ -27,7 +27,7 @@ bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam<FPGA> *param) {
auto input = const_cast<Tensor *>(param->InputX());
auto input_ptr = input->data<float>();
auto float_input = new Tensor;
float_input->mutable_data<float>(input->dims());
float_input->mutable_data<float>({1, input->dims()[1]});
fpga::format_fp32_ofm(float_input);
fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
......@@ -56,7 +56,6 @@ void SoftmaxKernel<FPGA, float>::Compute(
fpga::fpga_invalidate(
(void *)in_x->data<float>(), // NOLINT
fpga::get_align_image_cw(in_x->dims()[1]) * sizeof(float));
math::SoftmaxFuntor<CPU, float>()(in_x, out);
fpga::fpga_flush(out->data<float>(), out->memory_size());
}
......
此差异已折叠。
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MULTICLASSNMS_OP
#pragma once
#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
namespace gpc {
typedef enum { // Set operation type
GPC_DIFF, // Difference
GPC_INT, // Intersection
GPC_XOR, // Exclusive or
GPC_UNION // Union
} gpc_op;
typedef struct { // Polygon vertex structure
double x; // Vertex x component
double y; // vertex y component
} gpc_vertex;
typedef struct { // Vertex list structure
int num_vertices; // Number of vertices in list
gpc_vertex *vertex; // Vertex array pointer
} gpc_vertex_list;
typedef struct { // Polygon set structure
int num_contours; // Number of contours in polygon
int *hole; // Hole external contour flags
gpc_vertex_list *contour; // Contour array pointer
} gpc_polygon;
typedef struct { // Tristrip set structure
int num_strips; // Number of tristrips
gpc_vertex_list *strip; // Tristrip array pointer
} gpc_tristrip;
typedef enum { LEFT, RIGHT } gpc_left_right;
typedef enum { ABOVE, BELOW } gpc_above_below;
typedef enum { CLIP, SUBJ } gpc_clip_subj;
typedef enum { /* Edge intersection classes */
NUL, /* Empty non-intersection */
EMX, /* External maximum */
ELI, /* External left intermediate */
TED, /* Top edge */
ERI, /* External right intermediate */
RED, /* Right edge */
IMM, /* Internal maximum and minimum */
IMN, /* Internal minimum */
EMN, /* External minimum */
EMM, /* External maximum and minimum */
LED, /* Left edge */
ILI, /* Internal left intermediate */
BED, /* Bottom edge */
IRI, /* Internal right intermediate */
IMX, /* Internal maximum */
FUL /* Full non-intersection */
} vertex_type;
typedef enum { /* Horizontal edge states */
NH, /* No horizontal edge */
BH, /* Bottom horizontal edge */
TH /* Top horizontal edge */
} h_state;
typedef enum { /* Edge bundle state */
UNBUNDLED, /* Isolated edge not within a bundle */
BUNDLE_HEAD, /* Bundle head node */
BUNDLE_TAIL /* Passive bundle tail node */
} bundle_state;
typedef struct v_shape { /* Internal vertex list datatype */
double x; /* X coordinate component */
double y; /* Y coordinate component */
struct v_shape *next; /* Pointer to next vertex in list */
} vertex_node;
typedef struct p_shape { /* Internal contour / tristrip type */
int active; /* Active flag / vertex count */
int hole; /* Hole / external contour flag */
vertex_node *v[2]; /* Left and right vertex list ptrs */
struct p_shape *next; /* Pointer to next polygon contour */
struct p_shape *proxy; /* Pointer to actual structure used */
} polygon_node;
typedef struct edge_shape {
gpc_vertex vertex; /* Piggy-backed contour vertex data */
gpc_vertex bot; /* Edge lower (x, y) coordinate */
gpc_vertex top; /* Edge upper (x, y) coordinate */
double xb; /* Scanbeam bottom x coordinate */
double xt; /* Scanbeam top x coordinate */
double dx; /* Change in x for a unit y increase */
int type; /* Clip / subject edge flag */
int bundle[2][2]; /* Bundle edge flags */
int bside[2]; /* Bundle left / right indicators */
bundle_state bstate[2]; /* Edge bundle state */
polygon_node *outp[2]; /* Output polygon / tristrip pointer */
struct edge_shape *prev; /* Previous edge in the AET */
struct edge_shape *next; /* Next edge in the AET */
struct edge_shape *pred; /* Edge connected at the lower end */
struct edge_shape *succ; /* Edge connected at the upper end */
struct edge_shape *next_bound; /* Pointer to next bound in LMT */
} edge_node;
inline bool gpc_eq(float a, float b) { return (fabs(a - b) <= 1e-6); }
inline bool gpc_prev_index(float a, float b) { return (fabs(a - b) <= 1e-6); }
inline int gpc_prev_index(int i, int n) { return ((i - 1 + n) % n); }
inline int gpc_next_index(int i, int n) { return ((i + 1) % n); }
inline int gpc_optimal(gpc_vertex *v, int i, int n) {
return (v[(i + 1) % n].y != v[i].y || v[(i - 1 + n) % n].y != v[i].y);
}
inline int gpc_fwd_min(edge_node *v, int i, int n) {
return (v[(i + 1) % n].vertex.y > v[i].vertex.y &&
v[(i - 1 + n) % n].vertex.y >= v[i].vertex.y);
}
inline int gpc_not_fmax(edge_node *v, int i, int n) {
return (v[(i + 1) % n].vertex.y > v[i].vertex.y);
}
inline int gpc_rev_min(edge_node *v, int i, int n) {
return (v[(i + 1) % n].vertex.y >= v[i].vertex.y &&
v[(i - 1 + n) % n].vertex.y > v[i].vertex.y);
}
inline int gpc_not_rmax(edge_node *v, int i, int n) {
return (v[(i - 1 + n) % n].vertex.y > v[i].vertex.y);
}
// inline void gpc_p_edge(edge_node *d, edge_node *e, int p, double i, double j)
// {
inline void gpc_p_edge(edge_node *d, edge_node *e, int p) {
d = e;
do {
d = d->prev;
} while (!d->outp[p]);
// i = d->bot.x + d->dx * (j - d->bot.y);
}
// inline void gpc_n_edge(edge_node *d, edge_node *e, int p, double i, double j)
// {
inline void gpc_n_edge(edge_node *d, edge_node *e, int p) {
d = e;
do {
d = d->next;
} while (!d->outp[p]);
// i = d->bot.x + d->dx * (j - d->bot.y);
}
template <typename T>
void gpc_malloc(T *&p, int b, char *s) { // NOLINT
if (b > 0) {
p = reinterpret_cast<T *>(malloc(b));
if (!p) {
fprintf(stderr, "gpc malloc failure: %s\n", s);
exit(0);
}
} else {
p = NULL;
}
}
template <typename T>
void gpc_free(T *&p) { // NOLINT
if (p) {
free(p);
p = NULL;
}
}
/*
===========================================================================
Public Function Prototypes
===========================================================================
*/
void add_vertex(vertex_node **t, double x, double y);
void gpc_vertex_create(edge_node *e, int p, int s, double x, double y);
void gpc_add_contour(gpc_polygon *polygon, gpc_vertex_list *contour, int hole);
void gpc_polygon_clip(gpc_op set_operation, gpc_polygon *subject_polygon,
gpc_polygon *clip_polygon, gpc_polygon *result_polygon);
void gpc_tristrip_clip(gpc_op set_operation, gpc_polygon *subject_polygon,
gpc_polygon *clip_polygon,
gpc_tristrip *result_tristrip);
void gpc_polygon_to_tristrip(gpc_polygon *polygon, gpc_tristrip *tristrip);
void gpc_free_polygon(gpc_polygon *polygon);
void gpc_free_tristrip(gpc_tristrip *tristrip);
} // namespace gpc
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MULTICLASSNMS_OP
#include "operators/math/poly_util.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <class T>
void Array2PointVec(const T* box, const size_t box_size,
std::vector<Point_<T>>* vec) {
size_t pts_num = box_size / 2;
vec->resize(pts_num);
for (size_t i = 0; i < pts_num; i++) {
vec->at(i).x = box[2 * i];
vec->at(i).y = box[2 * i + 1];
}
}
template <class T>
void Array2Poly(const T* box, const size_t box_size, gpc::gpc_polygon* poly) {
size_t pts_num = box_size / 2;
poly->num_contours = 1;
poly->hole = reinterpret_cast<int*>(malloc(sizeof(int)));
poly->hole[0] = 0;
poly->contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list));
poly->contour->num_vertices = pts_num;
poly->contour->vertex =
(gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num);
for (size_t i = 0; i < pts_num; ++i) {
poly->contour->vertex[i].x = box[2 * i];
poly->contour->vertex[i].y = box[2 * i + 1];
}
}
template void Array2Poly(const float* box, const size_t box_size,
gpc::gpc_polygon* poly);
template <class T>
void Poly2PointVec(const gpc::gpc_vertex_list& contour,
std::vector<Point_<T>>* vec) {
int pts_num = contour.num_vertices;
vec->resize(pts_num);
for (size_t i = 0; i < pts_num; i++) {
vec->at(i).x = contour.vertex[i].x;
vec->at(i).y = contour.vertex[i].y;
}
}
template <class T>
T GetContourArea(const std::vector<Point_<T>>& vec) {
int pts_num = vec.size();
if (pts_num < 3) return T(0.);
T area = T(0.);
for (size_t i = 0; i < pts_num; ++i) {
area += vec[i].x * vec[(i + 1) % pts_num].y -
vec[i].y * vec[(i + 1) % pts_num].x;
}
return fabs(area / 2.0);
}
template <class T>
T PolyArea(const T* box, const size_t box_size, const bool normalized) {
// If coordinate values are is invalid
// if area size <= 0, return 0.
std::vector<Point_<T>> vec;
Array2PointVec<T>(box, box_size, &vec);
return GetContourArea<T>(vec);
}
template float PolyArea(const float* box, const size_t box_size,
const bool normalized);
template <class T>
T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size,
const bool normalized) {
gpc::gpc_polygon poly1;
gpc::gpc_polygon poly2;
Array2Poly<T>(box1, box_size, &poly1);
Array2Poly<T>(box2, box_size, &poly2);
gpc::gpc_polygon respoly;
gpc::gpc_op op = gpc::GPC_INT;
gpc::gpc_polygon_clip(op, &poly2, &poly1, &respoly);
T inter_area = T(0.);
int contour_num = respoly.num_contours;
for (int i = 0; i < contour_num; ++i) {
std::vector<Point_<T>> resvec;
Poly2PointVec<T>(respoly.contour[i], &resvec);
inter_area += GetContourArea<T>(resvec);
}
gpc::gpc_free_polygon(&poly1);
gpc::gpc_free_polygon(&poly2);
gpc::gpc_free_polygon(&respoly);
return inter_area;
}
template float PolyOverlapArea(const float* box1, const float* box2,
const size_t box_size, const bool normalized);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MULTICLASSNMS_OP
#pragma once
#include <vector>
#include "operators/math/gpc.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <class T>
class Point_ {
public:
// default constructor
Point_() {}
Point_(T _x, T _y) {}
Point_(const Point_& pt) {}
Point_& operator=(const Point_& pt);
// conversion to another data type
// template<typename _T> operator Point_<_T>() const;
// conversion to the old-style C structures
// operator Vec<T, 2>() const;
// checks whether the point is inside the specified rectangle
// bool inside(const Rect_<T>& r) const;
T x; //!< x coordinate of the point
T y; //!< y coordinate of the point
};
template <class T>
void Array2PointVec(const T* box, const size_t box_size,
std::vector<Point_<T>>* vec);
template <class T>
void Array2Poly(const T* box, const size_t box_size, gpc::gpc_polygon* poly);
template <class T>
void Poly2PointVec(const gpc::gpc_vertex_list& contour,
std::vector<Point_<T>>* vec);
template <class T>
T GetContourArea(const std::vector<Point_<T>>& vec);
template <class T>
T PolyArea(const T* box, const size_t box_size, const bool normalized);
template <class T>
T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size,
const bool normalized);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -47,7 +47,7 @@ struct SelectedRowsAddTo {
const int64_t input2_offset,
framework::SelectedRows* input2) {
auto in1_height = input1.height();
PADDLE_MOBILE_ENFORCE(in1_height == input2->height());
PADDLE_MOBILE_ENFORCE(in1_height == input2->height(), "height error");
auto& in1_rows = input1.rows();
auto& in2_rows = *(input2->mutable_rows());
......@@ -77,13 +77,14 @@ struct SelectedRowsAddToTensor {
framework::Tensor* input2) {
auto in1_height = input1.height();
auto in2_dims = input2->dims();
PADDLE_MOBILE_ENFORCE(in1_height == in2_dims[0]);
PADDLE_MOBILE_ENFORCE(in1_height == in2_dims[0], "height != dims[0]");
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_MOBILE_ENFORCE(in1_row_numel == input2->numel() / in1_height);
PADDLE_MOBILE_ENFORCE(in1_row_numel == input2->numel() / in1_height,
"row_numel error");
auto* in1_data = in1_value.data<T>();
auto* input2_data = input2->data<T>();
......
......@@ -25,8 +25,8 @@ void MultiClassNMSOp<Dtype, T>::InferShape() const {
if (input_scores_dims.size() != 3) {
LOG(kLOG_ERROR) << "Input Scores size must be 3";
}
if (input_bboxes_dims[2] != 4) {
LOG(kLOG_ERROR) << "Input BBoxes 2nd dimension must be 4";
if (input_bboxes_dims[2] % 4 != 0 || input_bboxes_dims[2] < 4) {
LOG(kLOG_ERROR) << "Input BBoxes 2nd dimension must be multiples of 4";
}
if (input_bboxes_dims[1] != input_scores_dims[2]) {
LOG(kLOG_ERROR) << "Predict bboxes must be equal";
......
......@@ -471,15 +471,6 @@ class ElementwiseMulParam : OpParam {
GType *input_y_;
GType *out_;
int axis_;
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::EWMulArgs fpga_EW_mul_args;
public:
const fpga::EWMulArgs &FpgaArgs() const { return fpga_EW_mul_args; }
void SetFpgaArgs(const fpga::EWMulArgs &args) { fpga_EW_mul_args = args; }
#endif
};
#endif
......@@ -488,6 +479,38 @@ template <typename Dtype>
using ElementwiseAddReluParam = ElementwiseAddParam<Dtype>;
#endif
#ifdef ELEMENTWISESUB_OP
template <typename Dtype>
class ElementwiseSubParam : OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
ElementwiseSubParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
input_y_ = InputYFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
axis_ = GetAttr<int>("axis", attrs);
}
const GType *InputX() const { return input_x_; }
const GType *InputY() const { return input_y_; }
GType *Out() const { return out_; }
const int &Axis() const { return axis_; }
private:
GType *input_x_;
GType *input_y_;
GType *out_;
int axis_;
};
#endif
#ifdef MUL_OP
template <typename Dtype>
class MulParam : OpParam {
......@@ -596,15 +619,6 @@ class SumParam : public OpParam {
Variable *out_var_;
vector<GType *> inputs_;
GType *out_;
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::SumArgs fpga_sum_args;
public:
const fpga::SumArgs &FpgaArgs() const { return fpga_sum_args; }
void SetFpgaArgs(const fpga::SumArgs &args) { fpga_sum_args = args; }
#endif
};
#endif
......
......@@ -173,6 +173,14 @@ if (NOT FOUND_MATCH)
target_link_libraries(test-elementwiseadd-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-elementwisesub-op operators/test_elementwise_sub_op.cpp test_helper.h test_include.h)
target_link_libraries(test-elementwisesub-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-im2sequence-op operators/test_im2sequence_op.cpp test_helper.h test_include.h)
target_link_libraries(test-im2sequence-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-concat-op operators/test_concat_op.cpp test_helper.h test_include.h)
target_link_libraries(test-concat-op paddle-mobile)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/elementwise_sub_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype>
class TestElementwiseSubOp {
public:
explicit TestElementwiseSubOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j];
if (op->Type() == "elementwise_sub" &&
op->Input("X")[0] == "sigmoid_1.tmp_0") {
DLOG << " elementwise_sub attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
std::shared_ptr<operators::ElementwiseSubOp<Dtype, float>> lrn =
std::make_shared<operators::ElementwiseSubOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(lrn);
}
}
}
}
std::shared_ptr<Tensor> predict_bn(const Tensor &t1, const Tensor &t2) {
// feed
auto scope = program_.scope;
Variable *x1_feed_value = scope->Var("tmp_0");
auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>();
tensor_x1->ShareDataWith(t1);
Variable *x2_feed_value = scope->Var("sigmoid_1.tmp_0");
auto tensor_x2 = x2_feed_value->GetMutable<LoDTensor>();
tensor_x2->ShareDataWith(t2);
Variable *output = scope->Var("tmp_1");
auto *output_tensor = output->GetMutable<LoDTensor>();
output_tensor->mutable_data<float>({1, 1, 6, 6});
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict_bn(t1, t2, 0);
return out_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
void predict_bn(const Tensor &t1, const Tensor &t2, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
}
}
};
template class TestElementwiseSubOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run ElementwiseSub Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_ocr) + "/model",
std::string(g_ocr) + "/params");
/// input x1 (1,1,6,6)
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {1, 1, 6, 6}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
/// input x2 (1,1,6,6)
paddle_mobile::framework::Tensor inputx2;
SetupTensor<float>(&inputx2, {1, 1, 6, 6}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx2_ptr = inputx2.data<float>();
paddle_mobile::framework::TestElementwiseSubOp<paddle_mobile::CPU>
testElementwiseSubOp(program);
auto output_op = testElementwiseSubOp.predict_bn(inputx1, inputx2);
auto *output_op_ptr = output_op->data<float>();
auto inputx1_dim = inputx1.numel() / inputx1.dims()[0];
DLOG << " input1 : ";
for (int i = 0; i < inputx1.dims()[0]; ++i) {
for (int j = 0; j < inputx1_dim; ++j) {
DLOGF("%f ", inputx1_ptr[i * inputx1_dim + j]);
}
DLOGF("\n");
}
auto inputx2_dim = inputx2.numel() / inputx2.dims()[0];
DLOG << " input2 : ";
for (int i = 0; i < inputx2.dims()[0]; ++i) {
for (int j = 0; j < inputx2_dim; ++j) {
DLOGF("%f ", inputx2_ptr[i * inputx2_dim + j]);
}
DLOGF("\n");
}
auto output_dim = output_op->numel() / output_op->dims()[0];
DLOG << " output : ";
for (int i = 0; i < output_op->dims()[0]; ++i) {
for (int j = 0; j < output_dim; ++j) {
DLOGF("%f ", output_op_ptr[i * output_dim + j]);
}
DLOGF("\n");
}
return 0;
}
......@@ -12,51 +12,129 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../executor_for_test.h"
#pragma once
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/im2sequence_op.h"
int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(g_ocr_recg);
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail");
namespace paddle_mobile {
namespace framework {
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::ReluOp<paddle_mobile::CPU, float>>
executor(program, "im2sequence");
template <typename Dtype>
class TestIm2SequenceOp {
public:
explicit TestIm2SequenceOp(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
// 1. input_tensors;
vector<Tensor> input_tensors;
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<OpDesc> op = ops[j];
if (op->Type() == "im2sequence" &&
op->Input("X")[0] == "conv2d_19.tmp_1") {
DLOG << " im2squence attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
Tensor input1;
auto input1_data = CreateInput<float>(&input1, {2, 2, 3, 3}, -1, 1);
input_tensors.push_back(input1);
std::shared_ptr<operators::Im2SequenceOp<Dtype, float>> lrn =
std::make_shared<operators::Im2SequenceOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(lrn);
}
}
}
}
// 2. input_names
vector<string> input_names({
"conv2d_19.tmp_1",
});
std::shared_ptr<Tensor> predict_bn(const Tensor &t1) {
// feed
auto scope = program_.scope;
Variable *x1_feed_value = scope->Var("conv2d_19.tmp_1");
auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>();
tensor_x1->ShareDataWith(t1);
// 3. output_names
vector<string> output_names({"im2sequence_0.tmp_0"});
Variable *output = scope->Var("im2sequence_0.tmp_0");
auto *output_tensor = output->GetMutable<LoDTensor>();
output_tensor->mutable_data<float>({2, 12});
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
// 4. out_dims;
vector<DDim> out_ddims;
auto out_ddim = paddle_mobile::framework::make_ddim({8, 9});
out_ddims.push_back(out_ddim);
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
auto output = executor.Predict<LoDTensor>(input_tensors, input_names,
output_names, out_ddims);
predict_bn(t1, 0);
return out_tensor;
}
auto output0_data = output[0]->data<float>();
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
for (int j = 0; j < input_tensors[0].numel(); ++j) {
DLOG << " value of input: " << input1_data[j];
void predict_bn(const Tensor &t1, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
DLOG << "op -> run()";
op->Run();
}
}
};
template class TestIm2SequenceOp<CPU>;
} // namespace framework
} // namespace paddle_mobile
for (int j = 0; j < output[0]->numel(); ++j) {
DLOG << " value of output: " << output0_data[j];
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run Im2Sequence Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_eng) + "/model",
std::string(g_eng) + "/params");
/// input x (4,10,2,2)
paddle_mobile::framework::Tensor inputx;
SetupTensor<float>(&inputx, {1, 2, 6, 2}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx_ptr = inputx.data<float>();
paddle_mobile::framework::TestIm2SequenceOp<paddle_mobile::CPU>
testIm2SequenceOp(program);
auto output_op = testIm2SequenceOp.predict_bn(inputx);
auto *output_op_ptr = output_op->data<float>();
auto input_dim = inputx.numel() / inputx.dims()[0];
DLOG << " input : ";
for (int i = 0; i < inputx.dims()[0]; ++i) {
for (int j = 0; j < input_dim; ++j) {
DLOGF("%f ", inputx_ptr[i * input_dim + j]);
}
DLOGF("\n");
}
auto output_dim = output_op->numel() / output_op->dims()[0];
DLOG << " output : ";
for (int i = 0; i < output_op->dims()[0]; ++i) {
for (int j = 0; j < output_dim; ++j) {
DLOGF("%f ", output_op_ptr[i * output_dim + j]);
}
DLOGF("\n");
}
return 0;
}
......@@ -127,18 +127,25 @@ int main() {
DLOG << "----------**********----------";
DLOG << "begin to run MulticlassNMS Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string("../../test/models/mobilenet+ssd"));
auto program = loader.Load(std::string(g_mobilenet_ssd));
/// input x (1,3,300,300)
paddle_mobile::framework::Tensor inputx1;
SetupTensor<float>(&inputx1, {10, 1917, 4}, static_cast<float>(0),
SetupTensor<float>(&inputx1, {1, 2, 4}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx1_ptr = inputx1.data<float>();
const float x1[] = {0, 0, 100, 100, 50, 50, 150, 150};
for (int i = 0; i < 8; ++i) {
*(inputx1_ptr + i) = x1[i];
}
paddle_mobile::framework::Tensor inputx2;
SetupTensor<float>(&inputx2, {10, 21, 1917}, static_cast<float>(0),
SetupTensor<float>(&inputx2, {1, 2, 2}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx2_ptr = inputx2.data<float>();
const float x2[] = {0.4, 0.3, 0.6, 0.7};
for (int i = 0; i < 4; ++i) {
*(inputx2_ptr + i) = x2[i];
}
paddle_mobile::framework::TestMultiClassNMSOp<paddle_mobile::CPU>
testMultiClassNMSOp(program);
......@@ -146,8 +153,26 @@ int main() {
auto output = testMultiClassNMSOp.predict(inputx1, inputx2);
auto *output_ptr = output->data<float>();
for (int i = 0; i < output->numel(); i++) {
for (int i = 0; i < output->numel(); ++i) {
DLOG << output_ptr[i];
}
// test multi point
paddle_mobile::framework::Tensor inputx3;
SetupTensor<float>(&inputx3, {1, 2, 8}, static_cast<float>(0),
static_cast<float>(1));
auto *inputx3_ptr = inputx3.data<float>();
const float x3[] = {0, 0, 100, 0, 100, 100, 0, 100,
50, 50, 150, 50, 150, 150, 50, 150};
for (int i = 0; i < 16; ++i) {
*(inputx3_ptr + i) = x3[i];
}
auto output2 = testMultiClassNMSOp.predict(inputx3, inputx2);
auto *output_ptr2 = output2->data<float>();
for (int i = 0; i < output2->numel(); ++i) {
DLOG << output_ptr2[i];
}
return 0;
}
......@@ -189,6 +189,8 @@ if(NOT FOUND_MATCH)
set(CONV_OP ON)
set(DEPTHWISECONV_OP ON)
set(ELEMENTWISEADD_OP ON)
set(ELEMENTWISESUB_OP ON)
set(IM2SEQUENCE_OP ON)
set(FUSION_CONVADD_OP ON)
set(FUSION_CONVADDPRELU_OP ON)
set(FUSION_CONVADDRELU_OP ON)
......@@ -264,6 +266,9 @@ endif()
if (ELEMENTWISEADD_OP)
add_definitions(-DELEMENTWISEADD_OP)
endif()
if (ELEMENTWISESUB_OP)
add_definitions(-DELEMENTWISESUB_OP)
endif()
if (FUSION_CONVADD_OP)
add_definitions(-DFUSION_CONVADD_OP)
endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册