diff --git a/src/fpga/api.cpp b/src/fpga/api.cpp index 138906c790574a4a0201180b5d18cd67960a7e1d..725895ae6a3da161af545646c2a74bda16be532f 100644 --- a/src/fpga/api.cpp +++ b/src/fpga/api.cpp @@ -22,7 +22,7 @@ limitations under the License. */ #include "fpga/filter.h" #include "fpga/image.h" #define FPGA_TEST_MODE -#define PADDLE_MOBILE_OS_LINUX +//#define PADDLE_MOBILE_OS_LINUX namespace paddle_mobile { namespace fpga { @@ -125,6 +125,7 @@ float fp16_2_fp32(half fp16_num) { } int ComputeBasicConv(const struct ConvArgs &args) { +#ifdef FPGA_TEST_MODE DLOG << "======Compute Basic Conv======"; DLOG << " relu_enabled:" << args.relu_enabled << " sb_address:" << args.sb_address @@ -144,7 +145,7 @@ int ComputeBasicConv(const struct ConvArgs &args) { << " stride_w:" << args.kernel.stride_w; DLOG << " out_address:" << args.output.address << " out_scale_address:" << args.output.scale_address; - +#endif return do_ioctl(IOCTL_CONFIG_CONV, &args); } @@ -192,8 +193,9 @@ int ComputeFpgaPool(const struct PoolingArgs &args) { int ComputeFpgaEWAdd(const struct EWAddArgs &args) { #ifdef FPGA_TEST_MODE DLOG << "=============ComputeFpgaEWAdd==========="; - DLOG << " relu_enabled:" << args.relu_enabled << " const0:" << args.const0 - << " const1:" << args.const1; + DLOG << " relu_enabled:" << args.relu_enabled + << " const0:" << fp16_2_fp32(short(args.const0)) + << " const1:" << fp16_2_fp32(short(args.const1)); DLOG << " image0_address:" << args.image0.address << " image0_scale_address:" << args.image0.scale_address << " image0_channels:" << args.image0.channels @@ -401,8 +403,8 @@ void fill_conv_arg(struct WrapperConvArgs *arg, framework::Tensor *input, arg->concat_arg.image_num = arg->split_num; arg->concat_arg.image_out = out_ptr; arg->concat_arg.scale_out = out->scale; - arg->concat_arg.height = (uint32_t)filter->dims()[2]; - arg->concat_arg.width = (uint32_t)filter->dims()[3]; + arg->concat_arg.height = (uint32_t)out->dims()[2]; + arg->concat_arg.width = (uint32_t)out->dims()[3]; int n = arg->split_num; arg->concat_arg.images_in = @@ -411,7 +413,6 @@ void fill_conv_arg(struct WrapperConvArgs *arg, framework::Tensor *input, (float **)fpga_malloc(n * sizeof(float *)); // NOLINT arg->concat_arg.channel_num = (uint32_t *)fpga_malloc(n * sizeof(uint32_t)); // NOLINT - arg->concat_arg.image_out = out_ptr; auto channel = (int)out->dims()[1]; // NOLINT int filter_num_per_div = get_filter_num_per_div(filter, group_num); diff --git a/src/fpga/bias_scale.cpp b/src/fpga/bias_scale.cpp index 50f1ed03f0121b5afdc41d427e5b52675994bd1e..23889d5b1fee3d8cb9e4673f42b18574366411eb 100644 --- a/src/fpga/bias_scale.cpp +++ b/src/fpga/bias_scale.cpp @@ -27,6 +27,9 @@ void align_element(float **data_in, int num_per_div_before_alignment, int num) { (num + num_per_div_before_alignment - 1) / num_per_div_before_alignment; int num_per_div_after_alignment = align_to_x(num_per_div_before_alignment, BS_NUM_ALIGNMENT); + if (num_per_div_before_alignment == num_per_div_after_alignment) { + return; + } int num_element = 2 * div_num * num_per_div_after_alignment; // including bias & scale float *ptr_aligned = diff --git a/src/fpga/filter.cpp b/src/fpga/filter.cpp index 34e0ad6f18f8e80d636e42630e03650c018a8825..c824b446ce3a4c3f13ad788780997a3920a1484c 100644 --- a/src/fpga/filter.cpp +++ b/src/fpga/filter.cpp @@ -210,12 +210,12 @@ void format_filter(float **data_in, int num, int channel, int height, int width, align_to_x(num_per_div_before_alignment, FILTER_NUM_ALIGNMENT); int div_num = (num + num_per_div_before_alignment - 1) / num_per_div_before_alignment; - int num_after_alignment = num_per_div_after_alignment * div_num; - + int residual = num % num_per_div_before_alignment; + int num_after_alignment = num_per_div_after_alignment * + ((residual == 0) ? div_num : (div_num - 1)) + + align_to_x(residual, FILTER_NUM_ALIGNMENT); quantize(data_in, data_size, max); - char **quantize_data = (char **)data_in; // NOLINT - convert_to_hwc(quantize_data, num, channel, height, width); align_element(quantize_data, num, chw); align_num(quantize_data, num_per_div_before_alignment, num, chw); diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index a2a6da34849641b4f99310621445cb312c7d5227..03fdd8d433cd40aa7ba4786f02221bd24bd3a050 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -199,6 +199,12 @@ LOAD_OP3(pool2d, CPU, MALI_GPU, FPGA); #ifdef MULTICLASSNMS_OP LOAD_OP1(multiclass_nms, CPU); #endif +#ifdef SUM_OP +LOAD_OP1(sum, CPU); +#endif +#ifdef ELEMENTWISEMUL_OP +LOAD_OP1(elementwise_mul, CPU); +#endif #ifdef SLICE_OP LOAD_OP2(slice, CPU, MALI_GPU); #endif @@ -206,5 +212,8 @@ LOAD_OP2(slice, CPU, MALI_GPU); LOAD_OP2(fusion_conv_bn, CPU, FPGA); LOAD_FUSION_MATCHER(fusion_conv_bn); #endif +#ifdef ELEMENTWISESUB_OP +LOAD_OP1(elementwise_sub, CPU) +#endif LOAD_OP1(quantize, CPU); LOAD_OP1(dequantize, CPU); diff --git a/src/operators/elementwise_sub_op.cpp b/src/operators/elementwise_sub_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5ec33ced29f02a524350ed907ef69f2a5dbfca8 --- /dev/null +++ b/src/operators/elementwise_sub_op.cpp @@ -0,0 +1,41 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef ELEMENTWISESUB_OP + +#include "operators/elementwise_sub_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void ElementwiseSubOp::InferShape() const { + auto x_dim = this->param_.InputX()->dims(); + this->param_.Out()->Resize(x_dim); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(elementwise_sub, ops::ElementwiseSubOp); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +REGISTER_OPERATOR_MALI_GPU(elementwise_sub, ops::ElementwiseSubOp); +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#endif diff --git a/src/operators/elementwise_sub_op.h b/src/operators/elementwise_sub_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2edd2581a9d3929a29459df60f514132796a53e2 --- /dev/null +++ b/src/operators/elementwise_sub_op.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef ELEMENTWISESUB_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "kernel/elementwise_sub_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { +using std::string; +template +class ElementwiseSubOp : public framework::OperatorWithKernel< + DeviceType, ElementwiseSubParam, + operators::ElementwiseSubKernel> { + public: + ElementwiseSubOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, ElementwiseSubParam, + operators::ElementwiseSubKernel>( + type, inputs, outputs, attrs, scope) {} + + using framework::OperatorWithKernel< + DeviceType, ElementwiseSubParam, + operators::ElementwiseSubKernel>::OperatorWithKernel; + void InferShape() const override; + + protected: +}; +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/elementwise_sub_kernel.cpp b/src/operators/kernel/arm/elementwise_sub_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d78b3e31098ef7ef929a0d2c00043fab7193b01c --- /dev/null +++ b/src/operators/kernel/arm/elementwise_sub_kernel.cpp @@ -0,0 +1,38 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef ELEMENTWISESUB_OP + +#include "operators/kernel/elementwise_sub_kernel.h" +#include "operators/kernel/central-arm-func/elementwise_sub_arm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool ElementwiseSubKernel::Init(ElementwiseSubParam *param) { + return true; +} + +template <> +void ElementwiseSubKernel::Compute( + const ElementwiseSubParam ¶m) const { + ElementwiseSubCompute(param); + param.Out()->set_lod(param.InputX()->lod()); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/elementwise_sub_arm_func.h b/src/operators/kernel/central-arm-func/elementwise_sub_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..663c65c83a0f5b76e292925ea8cb0994b0f99ad1 --- /dev/null +++ b/src/operators/kernel/central-arm-func/elementwise_sub_arm_func.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef ELEMENTWISESUB_OP + +#pragma once +#include "operators/math/elementwise_op_function.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +struct SubFunctor { + inline T operator()(T a, T b) const { return a - b; } +}; + +template +void ElementwiseSubCompute(const ElementwiseSubParam ¶m) { + const Tensor *input_x = param.InputX(); + const Tensor *input_y = param.InputY(); + Tensor *Out = param.Out(); + Out->mutable_data(); + int axis = param.Axis(); + ElementwiseComputeEx, float>(input_x, input_y, axis, + SubFunctor(), Out); +} + +template class ElementwiseSubKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/multiclass_nms_arm_func.h b/src/operators/kernel/central-arm-func/multiclass_nms_arm_func.h index 9de57910540b4c9f7ab807053add9c5af9947ae7..533edd69b6160115fb81066cb1928fb4246ca5be 100644 --- a/src/operators/kernel/central-arm-func/multiclass_nms_arm_func.h +++ b/src/operators/kernel/central-arm-func/multiclass_nms_arm_func.h @@ -20,14 +20,12 @@ limitations under the License. */ #include #include #include "framework/tensor.h" +#include "operators/math/poly_util.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { -constexpr int kOutputDim = 6; -constexpr int kBBoxSize = 4; - template bool SortScorePairDescend(const std::pair& pair1, const std::pair& pair2) { @@ -90,6 +88,21 @@ static inline T JaccardOverlap(const T* box1, const T* box2, } } +template +static inline T PolyIoU(const T* box1, const T* box2, const size_t box_size, + const bool normalized) { + T bbox1_area = math::PolyArea(box1, box_size, normalized); + T bbox2_area = math::PolyArea(box2, box_size, normalized); + T inter_area = math::PolyOverlapArea(box1, box2, box_size, normalized); + if (bbox1_area == 0 || bbox2_area == 0 || inter_area == 0) { + // If coordinate values are is invalid + // if area size <= 0, return 0. + return static_cast(0.); + } else { + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + template static inline void NMSFast(const framework::Tensor& bbox, const framework::Tensor& scores, @@ -116,8 +129,14 @@ static inline void NMSFast(const framework::Tensor& bbox, for (size_t k = 0; k < selected_indices->size(); ++k) { if (keep) { const int kept_idx = (*selected_indices)[k]; - T overlap = JaccardOverlap(bbox_data + idx * box_size, + T overlap = T(0.); + if (box_size == 4) { + overlap = JaccardOverlap(bbox_data + idx * box_size, bbox_data + kept_idx * box_size, true); + } else { + overlap = PolyIoU(bbox_data + idx * box_size, + bbox_data + kept_idx * box_size, box_size, true); + } keep = overlap <= adaptive_threshold; } else { break; @@ -190,6 +209,8 @@ void MultiClassOutput(const framework::Tensor& scores, const std::map>& selected_indices, framework::Tensor* outs) { int predict_dim = scores.dims()[1]; + int box_size = bboxes.dims()[1]; + int out_dim = bboxes.dims()[1] + 2; auto* scores_data = scores.data(); auto* bboxes_data = bboxes.data(); auto* odata = outs->data(); @@ -202,11 +223,11 @@ void MultiClassOutput(const framework::Tensor& scores, const std::vector& indices = it.second; for (size_t j = 0; j < indices.size(); ++j) { int idx = indices[j]; - const T* bdata = bboxes_data + idx * kBBoxSize; - odata[count * kOutputDim] = label; // label - odata[count * kOutputDim + 1] = sdata[idx]; // score + const T* bdata = bboxes_data + idx * box_size; + odata[count * out_dim] = label; // label + odata[count * out_dim + 1] = sdata[idx]; // score // xmin, ymin, xmax, ymax - std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T)); + std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T)); count++; } } @@ -256,7 +277,8 @@ void MultiClassNMSCompute(const MultiClassNMSParam& param) { float* od = outs->mutable_data({1}); od[0] = -1; } else { - outs->mutable_data({num_kept, kOutputDim}); + int64_t out_dim = box_dim + 2; + outs->mutable_data({num_kept, out_dim}); for (int64_t i = 0; i < batch_size; ++i) { framework::Tensor ins_score = input_scores->Slice(i, i + 1); ins_score.Resize({class_num, predict_dim}); diff --git a/src/operators/kernel/central-arm-func/sum_arm_func.h b/src/operators/kernel/central-arm-func/sum_arm_func.h index 0319f2b23418f36670ca51993e97726879f12ec1..25c1c51c7abd62a900665197ab4e221b76a3fa04 100644 --- a/src/operators/kernel/central-arm-func/sum_arm_func.h +++ b/src/operators/kernel/central-arm-func/sum_arm_func.h @@ -27,13 +27,11 @@ void SumCompute(const SumParam ¶m) { auto *outvar = param.OutVar(); bool in_place = outvar == inputsvars[0]; - DLOG << "11:"; if (outvar->IsType()) { auto *out = outvar->GetMutable(); if (!in_place) { out->mutable_data(); } - DLOG << "1:"; auto *outptr = out->data(); // auto result = Flatten(*out); @@ -62,7 +60,6 @@ void SumCompute(const SumParam ¶m) { } } else if (outvar->IsType()) { - DLOG << "2:"; std::unique_ptr in0; if (in_place) { // If is in_place, we store the input[0] to in0 @@ -119,12 +116,12 @@ void SumCompute(const SumParam ¶m) { if (sel_row.rows().size() == 0) { continue; } - PADDLE_MOBILE_ENFORCE(out->height() == sel_row.height()); + PADDLE_MOBILE_ENFORCE(out->height() == sel_row.height(), + "seletrows height != outheight"); functor(sel_row, offset, out); offset += sel_row.value().numel(); } } else if (outvar->IsType()) { - DLOG << "3:"; auto &out_array = *outvar->GetMutable(); for (size_t i = in_place ? 1 : 0; i < inputsvars.size(); ++i) { PADDLE_MOBILE_ENFORCE(inputsvars[i]->IsType(), @@ -140,7 +137,8 @@ void SumCompute(const SumParam ¶m) { framework::TensorCopy((*in_array)[i], &out_array[i]); out_array[i].set_lod((*in_array)[i].lod()); } else { - PADDLE_MOBILE_ENFORCE(out_array[i].lod() == (*in_array)[i].lod()); + PADDLE_MOBILE_ENFORCE(out_array[i].lod() == (*in_array)[i].lod(), + "outLod != inLod"); auto *inptr = (*in_array)[i].data(); auto *outptr = out_array[i].data(); @@ -152,9 +150,7 @@ void SumCompute(const SumParam ¶m) { } } } else { - DLOG << "2:"; if (outvar->IsType()) { - DLOG << "3: "; } PADDLE_MOBILE_THROW_EXCEPTION( "Unexpected branch, output variable type is %s", outvar->Type().name()); diff --git a/src/operators/kernel/elementwise_sub_kernel.h b/src/operators/kernel/elementwise_sub_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9516dcbd3de09debe233571eb5f60b3b8b19a2fa --- /dev/null +++ b/src/operators/kernel/elementwise_sub_kernel.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef ELEMENTWISEADD_OP + +#pragma once + +#include "framework/operator.h" +#include "operators/math/elementwise_op_function.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class ElementwiseSubKernel + : public framework::OpKernelBase> { + public: + void Compute(const ElementwiseSubParam ¶m) const; + bool Init(ElementwiseSubParam *param); +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/fpga/fc_relu_kernel.cpp b/src/operators/kernel/fpga/fc_relu_kernel.cpp index 904dd8a1da9e67d0c1283806e766d3a25dc27309..7c7bceaaee82617122da9c0fd2a5fa6b688f1153 100644 --- a/src/operators/kernel/fpga/fc_relu_kernel.cpp +++ b/src/operators/kernel/fpga/fc_relu_kernel.cpp @@ -44,6 +44,7 @@ bool FusionFcReluKernel::Init(FusionFcReluParam *param) { int width = (uint32_t)input_x->dims()[3]; int filter_channel = chw / height / width; + out->Resize(framework::make_ddim({1, channel, 1, 1})); filter->Resize(framework::make_ddim({num, filter_channel, height, width})); float max_value = fpga::filter_find_max(filter); fpga::format_fc_filter(filter, max_value); diff --git a/src/operators/kernel/fpga/fusion_fc_kernel.cpp b/src/operators/kernel/fpga/fusion_fc_kernel.cpp index 46dae1b2a076add9f17e4e5bc6d3a99ad583fb50..d543e1ea46bea09ee7331d03760633ee240454d5 100644 --- a/src/operators/kernel/fpga/fusion_fc_kernel.cpp +++ b/src/operators/kernel/fpga/fusion_fc_kernel.cpp @@ -45,6 +45,7 @@ bool FusionFcKernel::Init(FusionFcParam *param) { int width = (uint32_t)input_x->dims()[3]; int filter_channel = chw / height / width; + out->Resize(framework::make_ddim({1, channel, 1, 1})); filter->Resize(framework::make_ddim({num, filter_channel, height, width})); float max_value = fpga::filter_find_max(filter); fpga::format_fc_filter(filter, max_value); diff --git a/src/operators/kernel/fpga/mul_kernel.cpp b/src/operators/kernel/fpga/mul_kernel.cpp index 07aa4bcc43d28805ab0660bf89149c5ec5f1c732..9e282bd27b744cb48fccdc8e4602ae2fc9a1ad79 100644 --- a/src/operators/kernel/fpga/mul_kernel.cpp +++ b/src/operators/kernel/fpga/mul_kernel.cpp @@ -44,6 +44,7 @@ bool MulKernel::Init(MulParam *param) { int width = (uint32_t)input_x->dims()[3]; int filter_channel = chw / height / width; + out->Resize(framework::make_ddim({1, channel, 1, 1})); filter->Resize(framework::make_ddim({num, filter_channel, height, width})); float max_value = fpga::filter_find_max(filter); fpga::format_fc_filter(filter, max_value); diff --git a/src/operators/kernel/fpga/softmax_kernel.cpp b/src/operators/kernel/fpga/softmax_kernel.cpp index dba555708f505eb9bdf81d6f4487227c88f0a616..e36db57f4b4f18712df50b2b132cdd1032a41921 100644 --- a/src/operators/kernel/fpga/softmax_kernel.cpp +++ b/src/operators/kernel/fpga/softmax_kernel.cpp @@ -27,7 +27,7 @@ bool SoftmaxKernel::Init(SoftmaxParam *param) { auto input = const_cast(param->InputX()); auto input_ptr = input->data(); auto float_input = new Tensor; - float_input->mutable_data(input->dims()); + float_input->mutable_data({1, input->dims()[1]}); fpga::format_fp32_ofm(float_input); fpga::BypassArgs args = {fpga::DATA_TYPE_FP16}; @@ -56,7 +56,6 @@ void SoftmaxKernel::Compute( fpga::fpga_invalidate( (void *)in_x->data(), // NOLINT fpga::get_align_image_cw(in_x->dims()[1]) * sizeof(float)); - math::SoftmaxFuntor()(in_x, out); fpga::fpga_flush(out->data(), out->memory_size()); } diff --git a/src/operators/math/gpc.cpp b/src/operators/math/gpc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6b7700081a2ab6cb11187fad898e944390217db3 --- /dev/null +++ b/src/operators/math/gpc.cpp @@ -0,0 +1,2142 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef MULTICLASSNMS_OP + +#include "operators/math/gpc.h" + +namespace gpc { + +typedef struct lmt_shape { /* Local minima table */ + double y; /* Y coordinate at local minimum */ + edge_node *first_bound; /* Pointer to bound list */ + struct lmt_shape *next; /* Pointer to next local minimum */ +} lmt_node; + +typedef struct sbt_t_shape { /* Scanbeam tree */ + double y; /* Scanbeam node y value */ + struct sbt_t_shape *less; /* Pointer to nodes with lower y */ + struct sbt_t_shape *more; /* Pointer to nodes with higher y */ +} sb_tree; + +typedef struct it_shape { /* Intersection table */ + edge_node *ie[2]; /* Intersecting edge (bundle) pair */ + gpc_vertex point; /* Point of intersection */ + struct it_shape *next; /* The next intersection table node */ +} it_node; + +typedef struct st_shape { /* Sorted edge table */ + edge_node *edge; /* Pointer to AET edge */ + double xb; /* Scanbeam bottom x coordinate */ + double xt; /* Scanbeam top x coordinate */ + double dx; /* Change in x for a unit y increase */ + struct st_shape *prev; /* Previous edge in sorted list */ +} st_node; + +typedef struct bbox_shape { /* Contour axis-aligned bounding box */ + double xmin; /* Minimum x coordinate */ + double ymin; /* Minimum y coordinate */ + double xmax; /* Maximum x coordinate */ + double ymax; /* Maximum y coordinate */ +} bbox; + +/* +=========================================================================== + Global Data +=========================================================================== +*/ + +/* Horizontal edge state transitions within scanbeam boundary */ +const h_state next_h_state[3][6] = { + /* ABOVE BELOW CROSS */ + /* L R L R L R */ + /* NH */ + {BH, TH, TH, BH, NH, NH}, + /* BH */ + {NH, NH, NH, NH, TH, TH}, + /* TH */ + {NH, NH, NH, NH, BH, BH}}; + +/* +=========================================================================== + Private Functions +=========================================================================== +*/ + +static void reset_it(it_node **it) { + it_node *itn; + + while (*it) { + itn = (*it)->next; + gpc_free(*it); + *it = itn; + } +} + +static void reset_lmt(lmt_node **lmt) { + lmt_node *lmtn; + + while (*lmt) { + lmtn = (*lmt)->next; + gpc_free(*lmt); + *lmt = lmtn; + } +} + +static void insert_bound(edge_node **b, edge_node *e) { + edge_node *existing_bound = NULL; + + if (!*b) { + /* Link node e to the tail of the list */ + *b = e; + } else { + /* Do primary sort on the x field */ + if (e[0].bot.x < (*b)[0].bot.x) { + /* Insert a new node mid-list */ + existing_bound = *b; + *b = e; + (*b)->next_bound = existing_bound; + } else { + if (e[0].bot.x == (*b)[0].bot.x) { + /* Do secondary sort on the dx field */ + if (e[0].dx < (*b)[0].dx) { + /* Insert a new node mid-list */ + existing_bound = *b; + *b = e; + (*b)->next_bound = existing_bound; + } else { + /* Head further down the list */ + insert_bound(&((*b)->next_bound), e); + } + } else { + /* Head further down the list */ + insert_bound(&((*b)->next_bound), e); + } + } + } +} + +static edge_node **bound_list(lmt_node **lmt, double y) { + lmt_node *existing_node; + + if (!*lmt) { + /* Add node onto the tail end of the LMT */ + gpc_malloc(*lmt, sizeof(lmt_node), + const_cast("LMT insertion")); + (*lmt)->y = y; + (*lmt)->first_bound = NULL; + (*lmt)->next = NULL; + return &((*lmt)->first_bound); + } else if (y < (*lmt)->y) { + /* Insert a new LMT node before the current node */ + existing_node = *lmt; + gpc_malloc(*lmt, sizeof(lmt_node), + const_cast("LMT insertion")); + (*lmt)->y = y; + (*lmt)->first_bound = NULL; + (*lmt)->next = existing_node; + return &((*lmt)->first_bound); + } else { + if (y > (*lmt)->y) { + /* Head further up the LMT */ + return bound_list(&((*lmt)->next), y); + } else { + /* Use this existing LMT node */ + return &((*lmt)->first_bound); + } + } +} + +static void add_to_sbtree(int *entries, sb_tree **sbtree, double y) { + if (!*sbtree) { + /* Add a new tree node here */ + gpc_malloc(*sbtree, sizeof(sb_tree), + const_cast("scanbeam tree insertion")); + (*sbtree)->y = y; + (*sbtree)->less = NULL; + (*sbtree)->more = NULL; + (*entries)++; + } else { + if ((*sbtree)->y > y) { + /* Head into the 'less' sub-tree */ + add_to_sbtree(entries, &((*sbtree)->less), y); + } else { + if ((*sbtree)->y < y) { + /* Head into the 'more' sub-tree */ + add_to_sbtree(entries, &((*sbtree)->more), y); + } + } + } +} + +static void build_sbt(int *entries, double *sbt, sb_tree *sbtree) { + if (sbtree->less) { + build_sbt(entries, sbt, sbtree->less); + } + sbt[*entries] = sbtree->y; + (*entries)++; + if (sbtree->more) { + build_sbt(entries, sbt, sbtree->more); + } +} + +static void free_sbtree(sb_tree **sbtree) { + if (*sbtree) { + free_sbtree(&((*sbtree)->less)); + free_sbtree(&((*sbtree)->more)); + gpc_free(*sbtree); + } +} + +static int count_optimal_vertices(gpc_vertex_list c) { + int result = 0; + int i = 0; + + /* Ignore non-contributing contours */ + if (c.num_vertices > 0) { + for (i = 0; i < c.num_vertices; i++) { + /* Ignore superfluous vertices embedded in horizontal edges */ + if (gpc_optimal(c.vertex, i, c.num_vertices)) { + result++; + } + } + } + return result; +} + +static edge_node *build_lmt(lmt_node **lmt, sb_tree **sbtree, int *sbt_entries, + gpc_polygon *p, int type, gpc_op op) { + int c = 0; + int i = 0; + int min = 0; + int max = 0; + int num_edges = 0; + int v = 0; + int num_vertices = 0; + int total_vertices = 0; + int e_index = 0; + edge_node *e = NULL; + edge_node *edge_table = NULL; + + for (c = 0; c < p->num_contours; c++) { + total_vertices += count_optimal_vertices(p->contour[c]); + } + + /* Create the entire input polygon edge table in one go */ + gpc_malloc(edge_table, total_vertices * sizeof(edge_node), + const_cast("edge table creation")); + + for (c = 0; c < p->num_contours; c++) { + if (p->contour[c].num_vertices < 0) { + /* Ignore the non-contributing contour and repair the vertex count */ + p->contour[c].num_vertices = -p->contour[c].num_vertices; + } else { + /* Perform contour optimisation */ + num_vertices = 0; + for (i = 0; i < p->contour[c].num_vertices; i++) { + if (gpc_optimal(p->contour[c].vertex, i, p->contour[c].num_vertices)) { + edge_table[num_vertices].vertex.x = p->contour[c].vertex[i].x; + edge_table[num_vertices].vertex.y = p->contour[c].vertex[i].y; + + /* Record vertex in the scanbeam table */ + add_to_sbtree(sbt_entries, sbtree, edge_table[num_vertices].vertex.y); + + num_vertices++; + } + } + + /* Do the contour forward pass */ + for (min = 0; min < num_vertices; min++) { + /* If a forward local minimum... */ + if (gpc_fwd_min(edge_table, min, num_vertices)) { + /* Search for the next local maximum... */ + num_edges = 1; + max = gpc_next_index(min, num_vertices); + while (gpc_not_fmax(edge_table, max, num_vertices)) { + num_edges++; + max = gpc_next_index(max, num_vertices); + } + + /* Build the next edge list */ + e = &edge_table[e_index]; + e_index += num_edges; + v = min; + e[0].bstate[BELOW] = UNBUNDLED; + e[0].bundle[BELOW][CLIP] = 0; + e[0].bundle[BELOW][SUBJ] = 0; + for (i = 0; i < num_edges; i++) { + e[i].xb = edge_table[v].vertex.x; + e[i].bot.x = edge_table[v].vertex.x; + e[i].bot.y = edge_table[v].vertex.y; + + v = gpc_next_index(v, num_vertices); + + e[i].top.x = edge_table[v].vertex.x; + e[i].top.y = edge_table[v].vertex.y; + e[i].dx = (edge_table[v].vertex.x - e[i].bot.x) / + (e[i].top.y - e[i].bot.y); + e[i].type = type; + e[i].outp[ABOVE] = NULL; + e[i].outp[BELOW] = NULL; + e[i].next = NULL; + e[i].prev = NULL; + e[i].succ = + ((num_edges > 1) && (i < (num_edges - 1))) ? &(e[i + 1]) : NULL; + e[i].pred = ((num_edges > 1) && (i > 0)) ? &(e[i - 1]) : NULL; + e[i].next_bound = NULL; + e[i].bside[CLIP] = (op == GPC_DIFF) ? RIGHT : LEFT; + e[i].bside[SUBJ] = LEFT; + } + insert_bound(bound_list(lmt, edge_table[min].vertex.y), e); + } + } + + /* Do the contour reverse pass */ + for (min = 0; min < num_vertices; min++) { + /* If a reverse local minimum... */ + if (gpc_rev_min(edge_table, min, num_vertices)) { + /* Search for the previous local maximum... */ + num_edges = 1; + max = gpc_prev_index(min, num_vertices); + while (gpc_not_rmax(edge_table, max, num_vertices)) { + num_edges++; + max = gpc_prev_index(max, num_vertices); + } + + /* Build the previous edge list */ + e = &edge_table[e_index]; + e_index += num_edges; + v = min; + e[0].bstate[BELOW] = UNBUNDLED; + e[0].bundle[BELOW][CLIP] = 0; + e[0].bundle[BELOW][SUBJ] = 0; + for (i = 0; i < num_edges; i++) { + e[i].xb = edge_table[v].vertex.x; + e[i].bot.x = edge_table[v].vertex.x; + e[i].bot.y = edge_table[v].vertex.y; + + v = gpc_prev_index(v, num_vertices); + + e[i].top.x = edge_table[v].vertex.x; + e[i].top.y = edge_table[v].vertex.y; + e[i].dx = (edge_table[v].vertex.x - e[i].bot.x) / + (e[i].top.y - e[i].bot.y); + e[i].type = type; + e[i].outp[ABOVE] = NULL; + e[i].outp[BELOW] = NULL; + e[i].next = NULL; + e[i].prev = NULL; + e[i].succ = + ((num_edges > 1) && (i < (num_edges - 1))) ? &(e[i + 1]) : NULL; + e[i].pred = ((num_edges > 1) && (i > 0)) ? &(e[i - 1]) : NULL; + e[i].next_bound = NULL; + e[i].bside[CLIP] = (op == GPC_DIFF) ? RIGHT : LEFT; + e[i].bside[SUBJ] = LEFT; + } + insert_bound(bound_list(lmt, edge_table[min].vertex.y), e); + } + } + } + } + return edge_table; +} // NOLINT + +static void add_edge_to_aet(edge_node **aet, edge_node *edge, edge_node *prev) { + if (!*aet) { + /* Append edge onto the tail end of the AET */ + *aet = edge; + edge->prev = prev; + edge->next = NULL; + } else { + /* Do primary sort on the xb field */ + if (edge->xb < (*aet)->xb) { + /* Insert edge here (before the AET edge) */ + edge->prev = prev; + edge->next = *aet; + (*aet)->prev = edge; + *aet = edge; + } else { + if (edge->xb == (*aet)->xb) { + /* Do secondary sort on the dx field */ + if (edge->dx < (*aet)->dx) { + /* Insert edge here (before the AET edge) */ + edge->prev = prev; + edge->next = *aet; + (*aet)->prev = edge; + *aet = edge; + } else { + /* Head further into the AET */ + add_edge_to_aet(&((*aet)->next), edge, *aet); + } + } else { + /* Head further into the AET */ + add_edge_to_aet(&((*aet)->next), edge, *aet); + } + } + } +} + +static void add_intersection(it_node **it, edge_node *edge0, edge_node *edge1, + double x, double y) { + it_node *existing_node; + + if (!*it) { + /* Append a new node to the tail of the list */ + gpc_malloc(*it, sizeof(it_node), + const_cast("IT insertion")); + (*it)->ie[0] = edge0; + (*it)->ie[1] = edge1; + (*it)->point.x = x; + (*it)->point.y = y; + (*it)->next = NULL; + } else { + if ((*it)->point.y > y) { + /* Insert a new node mid-list */ + existing_node = *it; + gpc_malloc(*it, sizeof(it_node), + const_cast("IT insertion")); + (*it)->ie[0] = edge0; + (*it)->ie[1] = edge1; + (*it)->point.x = x; + (*it)->point.y = y; + (*it)->next = existing_node; + } else { + /* Head further down the list */ + add_intersection(&((*it)->next), edge0, edge1, x, y); + } + } +} + +static void add_st_edge(st_node **st, it_node **it, edge_node *edge, + double dy) { + st_node *existing_node; + double den = 0.0; + double r = 0.0; + double x = 0.0; + double y = 0.0; + + if (!*st) { + /* Append edge onto the tail end of the ST */ + gpc_malloc(*st, sizeof(st_node), + const_cast("ST insertion")); + (*st)->edge = edge; + (*st)->xb = edge->xb; + (*st)->xt = edge->xt; + (*st)->dx = edge->dx; + (*st)->prev = NULL; + } else { + den = ((*st)->xt - (*st)->xb) - (edge->xt - edge->xb); + + /* If new edge and ST edge don't cross */ + if ((edge->xt >= (*st)->xt) || (edge->dx == (*st)->dx) || + (fabs(den) <= DBL_EPSILON)) { + /* No intersection - insert edge here (before the ST edge) */ + existing_node = *st; + gpc_malloc(*st, sizeof(st_node), + const_cast("ST insertion")); + (*st)->edge = edge; + (*st)->xb = edge->xb; + (*st)->xt = edge->xt; + (*st)->dx = edge->dx; + (*st)->prev = existing_node; + } else { + /* Compute intersection between new edge and ST edge */ + r = (edge->xb - (*st)->xb) / den; + x = (*st)->xb + r * ((*st)->xt - (*st)->xb); + y = r * dy; + + /* Insert the edge pointers and the intersection point in the IT */ + add_intersection(it, (*st)->edge, edge, x, y); + + /* Head further into the ST */ + add_st_edge(&((*st)->prev), it, edge, dy); + } + } +} + +static void build_intersection_table(it_node **it, edge_node *aet, double dy) { + st_node *st; + st_node *stp; + edge_node *edge = NULL; + + /* Build intersection table for the current scanbeam */ + reset_it(it); + st = NULL; + + /* Process each AET edge */ + for (edge = aet; edge; edge = edge->next) { + if ((edge->bstate[ABOVE] == BUNDLE_HEAD) || edge->bundle[ABOVE][CLIP] || + edge->bundle[ABOVE][SUBJ]) { + add_st_edge(&st, it, edge, dy); + } + } + + /* Free the sorted edge table */ + while (st) { + stp = st->prev; + gpc_free(st); + st = stp; + } +} + +static int count_contours(polygon_node *polygon) { + int nc = 0; + int nv = 0; + vertex_node *v = NULL; + vertex_node *nextv = NULL; + + for (nc = 0; polygon; polygon = polygon->next) { + if (polygon->active) { + /* Count the vertices in the current contour */ + nv = 0; + for (v = polygon->proxy->v[LEFT]; v; v = v->next) { + nv++; + } + + /* Record valid vertex counts in the active field */ + if (nv > 2) { + polygon->active = nv; + nc++; + } else { + /* Invalid contour: just free the heap */ + for (v = polygon->proxy->v[LEFT]; v; v = nextv) { + nextv = v->next; + gpc_free(v); + } + polygon->active = 0; + } + } + } + return nc; +} + +static void add_left(polygon_node *p, double x, double y) { + vertex_node *nv = NULL; + + /* Create a new vertex node and set its fields */ + gpc_malloc(nv, sizeof(vertex_node), + const_cast("vertex node creation")); + nv->x = x; + nv->y = y; + + /* Add vertex nv to the left end of the polygon's vertex list */ + nv->next = p->proxy->v[LEFT]; + + /* Update proxy->[LEFT] to point to nv */ + p->proxy->v[LEFT] = nv; +} + +static void merge_left(polygon_node *p, polygon_node *q, polygon_node *list) { + polygon_node *target = NULL; + + /* Label contour as a hole */ + q->proxy->hole = 1; + + if (p->proxy != q->proxy) { + /* Assign p's vertex list to the left end of q's list */ + p->proxy->v[RIGHT]->next = q->proxy->v[LEFT]; + q->proxy->v[LEFT] = p->proxy->v[LEFT]; + + /* Redirect any p->proxy references to q->proxy */ + + for (target = p->proxy; list; list = list->next) { + if (list->proxy == target) { + list->active = 0; + list->proxy = q->proxy; + } + } + } +} + +static void add_right(polygon_node *p, double x, double y) { + vertex_node *nv = NULL; + + /* Create a new vertex node and set its fields */ + gpc_malloc(nv, sizeof(vertex_node), + const_cast("vertex node creation")); + nv->x = x; + nv->y = y; + nv->next = NULL; + + /* Add vertex nv to the right end of the polygon's vertex list */ + p->proxy->v[RIGHT]->next = nv; + + /* Update proxy->v[RIGHT] to point to nv */ + p->proxy->v[RIGHT] = nv; +} + +static void merge_right(polygon_node *p, polygon_node *q, polygon_node *list) { + polygon_node *target = NULL; + + /* Label contour as external */ + q->proxy->hole = 0; + + if (p->proxy != q->proxy) { + /* Assign p's vertex list to the right end of q's list */ + q->proxy->v[RIGHT]->next = p->proxy->v[LEFT]; + q->proxy->v[RIGHT] = p->proxy->v[RIGHT]; + + /* Redirect any p->proxy references to q->proxy */ + for (target = p->proxy; list; list = list->next) { + if (list->proxy == target) { + list->active = 0; + list->proxy = q->proxy; + } + } + } +} + +static void add_local_min(polygon_node **p, edge_node *edge, double x, + double y) { + polygon_node *existing_min = NULL; + vertex_node *nv = NULL; + + existing_min = *p; + + gpc_malloc(*p, sizeof(polygon_node), + const_cast("polygon node creation")); + + /* Create a new vertex node and set its fields */ + gpc_malloc(nv, sizeof(vertex_node), + const_cast("vertex node creation")); + nv->x = x; + nv->y = y; + nv->next = NULL; + + /* Initialise proxy to point to p itself */ + (*p)->proxy = (*p); + (*p)->active = 1; + (*p)->next = existing_min; + + /* Make v[LEFT] and v[RIGHT] point to new vertex nv */ + (*p)->v[LEFT] = nv; + (*p)->v[RIGHT] = nv; + + /* Assign polygon p to the edge */ + edge->outp[ABOVE] = *p; +} + +static int count_tristrips(polygon_node *tn) { + int total = 0; + + for (total = 0; tn; tn = tn->next) { + if (tn->active > 2) { + total++; + } + } + return total; +} + +void add_vertex(vertex_node **t, double x, double y) { + if (!(*t)) { + gpc_malloc(*t, sizeof(vertex_node), + const_cast("tristrip vertex creation")); + (*t)->x = x; + (*t)->y = y; + (*t)->next = NULL; + } else { + /* Head further down the list */ + add_vertex(&((*t)->next), x, y); + } +} + +void gpc_vertex_create(edge_node *e, int p, int s, double x, double y) { + add_vertex(&(e->outp[p]->v[s]), x, y); + e->outp[p]->active++; +} + +static void new_tristrip(polygon_node **tn, edge_node *edge, double x, + double y) { + if (!(*tn)) { + gpc_malloc(*tn, sizeof(polygon_node), + const_cast("tristrip node creation")); + (*tn)->next = NULL; + (*tn)->v[LEFT] = NULL; + (*tn)->v[RIGHT] = NULL; + (*tn)->active = 1; + add_vertex(&((*tn)->v[LEFT]), x, y); + edge->outp[ABOVE] = *tn; + } else { + /* Head further down the list */ + new_tristrip(&((*tn)->next), edge, x, y); + } +} + +static bbox *create_contour_bboxes(gpc_polygon *p) { + bbox *box; + int c = 0; + int v = 0; + + gpc_malloc(box, p->num_contours * sizeof(bbox), + const_cast("Bounding box creation")); + + /* Construct contour bounding boxes */ + for (c = 0; c < p->num_contours; c++) { + /* Initialise bounding box extent */ + box[c].xmin = DBL_MAX; + box[c].ymin = DBL_MAX; + box[c].xmax = -DBL_MAX; + box[c].ymax = -DBL_MAX; + + for (v = 0; v < p->contour[c].num_vertices; v++) { + /* Adjust bounding box */ + if (p->contour[c].vertex[v].x < box[c].xmin) { + box[c].xmin = p->contour[c].vertex[v].x; + } + if (p->contour[c].vertex[v].y < box[c].ymin) { + box[c].ymin = p->contour[c].vertex[v].y; + } + if (p->contour[c].vertex[v].x > box[c].xmax) { + box[c].xmax = p->contour[c].vertex[v].x; + } + if (p->contour[c].vertex[v].y > box[c].ymax) { + box[c].ymax = p->contour[c].vertex[v].y; + } + } + } + return box; +} + +static void minimax_test(gpc_polygon *subj, gpc_polygon *clip, gpc_op op) { + bbox *s_bbox; + bbox *c_bbox; + int s = 0; + int c = 0; + int *o_table = NULL; + int overlap = 0; + + s_bbox = create_contour_bboxes(subj); + c_bbox = create_contour_bboxes(clip); + + gpc_malloc(o_table, + subj->num_contours * clip->num_contours * sizeof(int), + const_cast("overlap table creation")); + + /* Check all subject contour bounding boxes against clip boxes */ + for (s = 0; s < subj->num_contours; s++) { + for (c = 0; c < clip->num_contours; c++) { + o_table[c * subj->num_contours + s] = + (!((s_bbox[s].xmax < c_bbox[c].xmin) || + (s_bbox[s].xmin > c_bbox[c].xmax))) && + (!((s_bbox[s].ymax < c_bbox[c].ymin) || + (s_bbox[s].ymin > c_bbox[c].ymax))); + } + } + + /* For each clip contour, search for any subject contour overlaps */ + for (c = 0; c < clip->num_contours; c++) { + overlap = 0; + for (s = 0; (!overlap) && (s < subj->num_contours); s++) { + overlap = o_table[c * subj->num_contours + s]; + } + + if (!overlap) { + /* Flag non contributing status by negating vertex count */ + clip->contour[c].num_vertices = -clip->contour[c].num_vertices; + } + } + + if (op == GPC_INT) { + /* For each subject contour, search for any clip contour overlaps */ + for (s = 0; s < subj->num_contours; s++) { + overlap = 0; + for (c = 0; (!overlap) && (c < clip->num_contours); c++) { + overlap = o_table[c * subj->num_contours + s]; + } + + if (!overlap) { + /* Flag non contributing status by negating vertex count */ + subj->contour[s].num_vertices = -subj->contour[s].num_vertices; + } + } + } + + gpc_free(s_bbox); + gpc_free(c_bbox); + gpc_free(o_table); +} + +/* +=========================================================================== + Public Functions +=========================================================================== +*/ + +void gpc_free_polygon(gpc_polygon *p) { + int c = 0; + + for (c = 0; c < p->num_contours; c++) { + gpc_free(p->contour[c].vertex); + } + gpc_free(p->hole); + gpc_free(p->contour); + p->num_contours = 0; +} + +void gpc_add_contour(gpc_polygon *p, gpc_vertex_list *new_contour, int hole) { + int *extended_hole = NULL; + int c = 0; + int v = 0; + gpc_vertex_list *extended_contour = NULL; + + /* Create an extended hole array */ + gpc_malloc(extended_hole, (p->num_contours + 1) * sizeof(int), + const_cast("contour hole addition")); + + /* Create an extended contour array */ + gpc_malloc(extended_contour, + (p->num_contours + 1) * sizeof(gpc_vertex_list), + const_cast("contour addition")); + + /* Copy the old contour and hole data into the extended arrays */ + for (c = 0; c < p->num_contours; c++) { + extended_hole[c] = p->hole[c]; + extended_contour[c] = p->contour[c]; + } + + /* Copy the new contour and hole onto the end of the extended arrays */ + c = p->num_contours; + extended_hole[c] = hole; + extended_contour[c].num_vertices = new_contour->num_vertices; + gpc_malloc(extended_contour[c].vertex, + new_contour->num_vertices * sizeof(gpc_vertex), + const_cast("contour addition")); + for (v = 0; v < new_contour->num_vertices; v++) { + extended_contour[c].vertex[v] = new_contour->vertex[v]; + } + + /* Dispose of the old contour */ + gpc_free(p->contour); + gpc_free(p->hole); + + /* Update the polygon information */ + p->num_contours++; + p->hole = extended_hole; + p->contour = extended_contour; +} + +// gpc_polygon_clip +void gpc_polygon_clip(gpc_op op, gpc_polygon *subj, gpc_polygon *clip, + gpc_polygon *result) { + sb_tree *sbtree = NULL; + it_node *it = NULL; + it_node *intersect = NULL; + edge_node *edge = NULL; + edge_node *prev_edge = NULL; + edge_node *next_edge = NULL; + edge_node *succ_edge = NULL; + edge_node *e0 = NULL; + edge_node *e1 = NULL; + edge_node *aet = NULL; + edge_node *c_heap = NULL; + edge_node *s_heap = NULL; + lmt_node *lmt = NULL; + lmt_node *local_min = NULL; + polygon_node *out_poly = NULL; + polygon_node *p = NULL; + polygon_node *q = NULL; + polygon_node *poly = NULL; + polygon_node *npoly = NULL; + polygon_node *cf = NULL; + vertex_node *vtx = NULL; + vertex_node *nv = NULL; + h_state horiz[2]; + int in[2]; + int exists[2]; + int parity[2] = {LEFT, LEFT}; + int c = 0; + int v = 0; + int contributing = 0; + int search = 0; + int scanbeam = 0; + int sbt_entries = 0; + int vclass = 0; + int bl = 0; + int br = 0; + int tl = 0; + int tr = 0; + double *sbt = NULL; + double xb = 0.0; + double px = 0.0; + double yb = 0.0; + double yt = 0.0; + double dy = 0.0; + double ix = 0.0; + double iy = 0.0; + + /* Test for trivial NULL result cases */ + if (((subj->num_contours == 0) && (clip->num_contours == 0)) || + ((subj->num_contours == 0) && ((op == GPC_INT) || (op == GPC_DIFF))) || + ((clip->num_contours == 0) && (op == GPC_INT))) { + result->num_contours = 0; + result->hole = NULL; + result->contour = NULL; + return; + } + /* Identify potentialy contributing contours */ + if (((op == GPC_INT) || (op == GPC_DIFF)) && (subj->num_contours > 0) && + (clip->num_contours > 0)) { + minimax_test(subj, clip, op); + } + /* Build LMT */ + if (subj->num_contours > 0) { + s_heap = build_lmt(&lmt, &sbtree, &sbt_entries, subj, SUBJ, op); + } + if (clip->num_contours > 0) { + c_heap = build_lmt(&lmt, &sbtree, &sbt_entries, clip, CLIP, op); + } + /* Return a NULL result if no contours contribute */ + if (lmt == NULL) { + result->num_contours = 0; + result->hole = NULL; + result->contour = NULL; + reset_lmt(&lmt); + gpc_free(s_heap); + gpc_free(c_heap); + return; + } + + /* Build scanbeam table from scanbeam tree */ + gpc_malloc(sbt, sbt_entries * sizeof(double), + const_cast("sbt creation")); + build_sbt(&scanbeam, sbt, sbtree); + scanbeam = 0; + free_sbtree(&sbtree); + /* Allow pointer re-use without causing memory leak */ + if (subj == result) { + gpc_free_polygon(subj); + } + if (clip == result) { + gpc_free_polygon(clip); + } + /* Invert clip polygon for difference operation */ + if (op == GPC_DIFF) { + parity[CLIP] = RIGHT; + } + local_min = lmt; + + // Process each scanbeam + while (scanbeam < sbt_entries) { + /* Set yb and yt to the bottom and top of the scanbeam */ + yb = sbt[scanbeam++]; + if (scanbeam < sbt_entries) { + yt = sbt[scanbeam]; + dy = yt - yb; + } + /* === SCANBEAM BOUNDARY PROCESSING ================================ */ + /* If LMT node corresponding to yb exists */ + if (local_min) { + if (local_min->y == yb) { + /* Add edges starting at this local minimum to the AET */ + for (edge = local_min->first_bound; edge; edge = edge->next_bound) { + add_edge_to_aet(&aet, edge, NULL); + } + local_min = local_min->next; + } + } + /* Set dummy previous x value */ + px = -DBL_MAX; + /* Create bundles within AET */ + e0 = aet; + e1 = aet; + /* Set up bundle fields of first edge */ + aet->bundle[ABOVE][aet->type] = (aet->top.y != yb); + aet->bundle[ABOVE][!aet->type] = 0; + aet->bstate[ABOVE] = UNBUNDLED; + + for (next_edge = aet->next; next_edge; next_edge = next_edge->next) { + /* Set up bundle fields of next edge */ + next_edge->bundle[ABOVE][next_edge->type] = (next_edge->top.y != yb); + next_edge->bundle[ABOVE][!next_edge->type] = 0; + next_edge->bstate[ABOVE] = UNBUNDLED; + /* Bundle edges above the scanbeam boundary if they coincide */ + if (next_edge->bundle[ABOVE][next_edge->type]) { + if (gpc_eq(e0->xb, next_edge->xb) && gpc_eq(e0->dx, next_edge->dx) && + (e0->top.y != yb)) { + next_edge->bundle[ABOVE][next_edge->type] ^= + e0->bundle[ABOVE][next_edge->type]; + next_edge->bundle[ABOVE][!next_edge->type] = + e0->bundle[ABOVE][!next_edge->type]; + next_edge->bstate[ABOVE] = BUNDLE_HEAD; + e0->bundle[ABOVE][CLIP] = 0; + e0->bundle[ABOVE][SUBJ] = 0; + e0->bstate[ABOVE] = BUNDLE_TAIL; + } + e0 = next_edge; + } + } + horiz[CLIP] = NH; + horiz[SUBJ] = NH; + + // Process each edge at this scanbeam boundary + for (edge = aet; edge; edge = edge->next) { + exists[CLIP] = + edge->bundle[ABOVE][CLIP] + (edge->bundle[BELOW][CLIP] << 1); + exists[SUBJ] = + edge->bundle[ABOVE][SUBJ] + (edge->bundle[BELOW][SUBJ] << 1); + if (exists[CLIP] || exists[SUBJ]) { + /* Set bundle side */ + edge->bside[CLIP] = parity[CLIP]; + edge->bside[SUBJ] = parity[SUBJ]; + /* Determine contributing status and quadrant occupancies */ + switch (op) { + case GPC_DIFF: + case GPC_INT: + contributing = (exists[CLIP] && (parity[SUBJ] || horiz[SUBJ])) || + (exists[SUBJ] && (parity[CLIP] || horiz[CLIP])) || + (exists[CLIP] && exists[SUBJ] && + (parity[CLIP] == parity[SUBJ])); + br = (parity[CLIP]) && (parity[SUBJ]); + bl = (parity[CLIP] ^ edge->bundle[ABOVE][CLIP]) && + (parity[SUBJ] ^ edge->bundle[ABOVE][SUBJ]); + tr = (parity[CLIP] ^ (horiz[CLIP] != NH)) && + (parity[SUBJ] ^ (horiz[SUBJ] != NH)); + tl = (parity[CLIP] ^ (horiz[CLIP] != NH) ^ + edge->bundle[BELOW][CLIP]) && + (parity[SUBJ] ^ (horiz[SUBJ] != NH) ^ + edge->bundle[BELOW][SUBJ]); + break; + case GPC_XOR: + contributing = exists[CLIP] || exists[SUBJ]; + br = (parity[CLIP]) ^ (parity[SUBJ]); + bl = (parity[CLIP] ^ edge->bundle[ABOVE][CLIP]) ^ + (parity[SUBJ] ^ edge->bundle[ABOVE][SUBJ]); + tr = (parity[CLIP] ^ (horiz[CLIP] != NH)) ^ + (parity[SUBJ] ^ (horiz[SUBJ] != NH)); + tl = (parity[CLIP] ^ (horiz[CLIP] != NH) ^ + edge->bundle[BELOW][CLIP]) ^ + (parity[SUBJ] ^ (horiz[SUBJ] != NH) ^ + edge->bundle[BELOW][SUBJ]); + break; + case GPC_UNION: + contributing = (exists[CLIP] && (!parity[SUBJ] || horiz[SUBJ])) || + (exists[SUBJ] && (!parity[CLIP] || horiz[CLIP])) || + (exists[CLIP] && exists[SUBJ] && + (parity[CLIP] == parity[SUBJ])); + br = (parity[CLIP]) || (parity[SUBJ]); + bl = (parity[CLIP] ^ edge->bundle[ABOVE][CLIP]) || + (parity[SUBJ] ^ edge->bundle[ABOVE][SUBJ]); + tr = (parity[CLIP] ^ (horiz[CLIP] != NH)) || + (parity[SUBJ] ^ (horiz[SUBJ] != NH)); + tl = (parity[CLIP] ^ (horiz[CLIP] != NH) ^ + edge->bundle[BELOW][CLIP]) || + (parity[SUBJ] ^ (horiz[SUBJ] != NH) ^ + edge->bundle[BELOW][SUBJ]); + break; + } + // Update parity + parity[CLIP] ^= edge->bundle[ABOVE][CLIP]; + parity[SUBJ] ^= edge->bundle[ABOVE][SUBJ]; + /* Update horizontal state */ + if (exists[CLIP]) { + horiz[CLIP] = next_h_state[horiz[CLIP]] + [((exists[CLIP] - 1) << 1) + parity[CLIP]]; + } + if (exists[SUBJ]) { + horiz[SUBJ] = next_h_state[horiz[SUBJ]] + [((exists[SUBJ] - 1) << 1) + parity[SUBJ]]; + } + vclass = tr + (tl << 1) + (br << 2) + (bl << 3); + if (contributing) { + xb = edge->xb; + switch (vclass) { + case EMN: + case IMN: + add_local_min(&out_poly, edge, xb, yb); + px = xb; + cf = edge->outp[ABOVE]; + break; + case ERI: + if (xb != px) { + add_right(cf, xb, yb); + px = xb; + } + edge->outp[ABOVE] = cf; + cf = NULL; + break; + case ELI: + add_left(edge->outp[BELOW], xb, yb); + px = xb; + cf = edge->outp[BELOW]; + break; + case EMX: + if (xb != px) { + add_left(cf, xb, yb); + px = xb; + } + merge_right(cf, edge->outp[BELOW], out_poly); + cf = NULL; + break; + case ILI: + if (xb != px) { + add_left(cf, xb, yb); + px = xb; + } + edge->outp[ABOVE] = cf; + cf = NULL; + break; + case IRI: + add_right(edge->outp[BELOW], xb, yb); + px = xb; + cf = edge->outp[BELOW]; + edge->outp[BELOW] = NULL; + break; + case IMX: + if (xb != px) { + add_right(cf, xb, yb); + px = xb; + } + merge_left(cf, edge->outp[BELOW], out_poly); + cf = NULL; + edge->outp[BELOW] = NULL; + break; + case IMM: + if (xb != px) { + add_right(cf, xb, yb); + px = xb; + } + merge_left(cf, edge->outp[BELOW], out_poly); + edge->outp[BELOW] = NULL; + add_local_min(&out_poly, edge, xb, yb); + cf = edge->outp[ABOVE]; + break; + case EMM: + if (xb != px) { + add_left(cf, xb, yb); + px = xb; + } + merge_right(cf, edge->outp[BELOW], out_poly); + edge->outp[BELOW] = NULL; + add_local_min(&out_poly, edge, xb, yb); + cf = edge->outp[ABOVE]; + break; + case LED: + if (edge->bot.y == yb) { + add_left(edge->outp[BELOW], xb, yb); + } + edge->outp[ABOVE] = edge->outp[BELOW]; + px = xb; + break; + case RED: + if (edge->bot.y == yb) { + add_right(edge->outp[BELOW], xb, yb); + } + edge->outp[ABOVE] = edge->outp[BELOW]; + px = xb; + break; + default: + break; + } /* End of switch */ + } /* End of contributing conditional */ + } /* End of edge exists conditional */ + } // End of AET loop + + /* Delete terminating edges from the AET, otherwise compute xt */ + for (edge = aet; edge; edge = edge->next) { + if (edge->top.y == yb) { + prev_edge = edge->prev; + next_edge = edge->next; + if (prev_edge) { + prev_edge->next = next_edge; + } else { + aet = next_edge; + } + if (next_edge) { + next_edge->prev = prev_edge; + } + /* Copy bundle head state to the adjacent tail edge if required */ + if ((edge->bstate[BELOW] == BUNDLE_HEAD) && prev_edge) { + if (prev_edge->bstate[BELOW] == BUNDLE_TAIL) { + prev_edge->outp[BELOW] = edge->outp[BELOW]; + prev_edge->bstate[BELOW] = UNBUNDLED; + if (prev_edge->prev) { + if (prev_edge->prev->bstate[BELOW] == BUNDLE_TAIL) { + prev_edge->bstate[BELOW] = BUNDLE_HEAD; + } + } + } + } + } else { + if (edge->top.y == yt) { + edge->xt = edge->top.x; + } else { + edge->xt = edge->bot.x + edge->dx * (yt - edge->bot.y); + } + } + } + + if (scanbeam < sbt_entries) { + /* === SCANBEAM INTERIOR PROCESSING ============================== */ + build_intersection_table(&it, aet, dy); + /* Process each node in the intersection table */ + for (intersect = it; intersect; intersect = intersect->next) { + e0 = intersect->ie[0]; + e1 = intersect->ie[1]; + /* Only generate output for contributing intersections */ + if ((e0->bundle[ABOVE][CLIP] || e0->bundle[ABOVE][SUBJ]) && + (e1->bundle[ABOVE][CLIP] || e1->bundle[ABOVE][SUBJ])) { + p = e0->outp[ABOVE]; + q = e1->outp[ABOVE]; + ix = intersect->point.x; + iy = intersect->point.y + yb; + + in[CLIP] = (e0->bundle[ABOVE][CLIP] && !e0->bside[CLIP]) || + (e1->bundle[ABOVE][CLIP] && e1->bside[CLIP]) || + (!e0->bundle[ABOVE][CLIP] && !e1->bundle[ABOVE][CLIP] && + e0->bside[CLIP] && e1->bside[CLIP]); + in[SUBJ] = (e0->bundle[ABOVE][SUBJ] && !e0->bside[SUBJ]) || + (e1->bundle[ABOVE][SUBJ] && e1->bside[SUBJ]) || + (!e0->bundle[ABOVE][SUBJ] && !e1->bundle[ABOVE][SUBJ] && + e0->bside[SUBJ] && e1->bside[SUBJ]); + + // Determine quadrant occupancies + switch (op) { + case GPC_DIFF: + case GPC_INT: + tr = (in[CLIP]) && (in[SUBJ]); + tl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP]) && + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ]); + br = (in[CLIP] ^ e0->bundle[ABOVE][CLIP]) && + (in[SUBJ] ^ e0->bundle[ABOVE][SUBJ]); + bl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP] ^ + e0->bundle[ABOVE][CLIP]) && + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ] ^ + e0->bundle[ABOVE][SUBJ]); + break; + case GPC_XOR: + tr = (in[CLIP]) ^ (in[SUBJ]); + tl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP]) ^ + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ]); + br = (in[CLIP] ^ e0->bundle[ABOVE][CLIP]) ^ + (in[SUBJ] ^ e0->bundle[ABOVE][SUBJ]); + bl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP] ^ + e0->bundle[ABOVE][CLIP]) ^ + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ] ^ + e0->bundle[ABOVE][SUBJ]); + break; + case GPC_UNION: + tr = (in[CLIP]) || (in[SUBJ]); + tl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP]) || + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ]); + br = (in[CLIP] ^ e0->bundle[ABOVE][CLIP]) || + (in[SUBJ] ^ e0->bundle[ABOVE][SUBJ]); + bl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP] ^ + e0->bundle[ABOVE][CLIP]) || + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ] ^ + e0->bundle[ABOVE][SUBJ]); + break; + } + vclass = tr + (tl << 1) + (br << 2) + (bl << 3); + switch (vclass) { + case EMN: + add_local_min(&out_poly, e0, ix, iy); + e1->outp[ABOVE] = e0->outp[ABOVE]; + break; + case ERI: + if (p) { + add_right(p, ix, iy); + e1->outp[ABOVE] = p; + e0->outp[ABOVE] = NULL; + } + break; + case ELI: + if (q) { + add_left(q, ix, iy); + e0->outp[ABOVE] = q; + e1->outp[ABOVE] = NULL; + } + break; + case EMX: + if (p && q) { + add_left(p, ix, iy); + merge_right(p, q, out_poly); + e0->outp[ABOVE] = NULL; + e1->outp[ABOVE] = NULL; + } + break; + case IMN: + add_local_min(&out_poly, e0, ix, iy); + e1->outp[ABOVE] = e0->outp[ABOVE]; + break; + case ILI: + if (p) { + add_left(p, ix, iy); + e1->outp[ABOVE] = p; + e0->outp[ABOVE] = NULL; + } + break; + case IRI: + if (q) { + add_right(q, ix, iy); + e0->outp[ABOVE] = q; + e1->outp[ABOVE] = NULL; + } + break; + case IMX: + if (p && q) { + add_right(p, ix, iy); + merge_left(p, q, out_poly); + e0->outp[ABOVE] = NULL; + e1->outp[ABOVE] = NULL; + } + break; + case IMM: + if (p && q) { + add_right(p, ix, iy); + merge_left(p, q, out_poly); + add_local_min(&out_poly, e0, ix, iy); + e1->outp[ABOVE] = e0->outp[ABOVE]; + } + break; + case EMM: + if (p && q) { + add_left(p, ix, iy); + merge_right(p, q, out_poly); + add_local_min(&out_poly, e0, ix, iy); + e1->outp[ABOVE] = e0->outp[ABOVE]; + } + break; + default: + break; + } // End of switch + } /* End of contributing intersection conditional */ + + /* Swap bundle sides in response to edge crossing */ + if (e0->bundle[ABOVE][CLIP]) { + e1->bside[CLIP] = !e1->bside[CLIP]; + } + if (e1->bundle[ABOVE][CLIP]) { + e0->bside[CLIP] = !e0->bside[CLIP]; + } + if (e0->bundle[ABOVE][SUBJ]) { + e1->bside[SUBJ] = !e1->bside[SUBJ]; + } + if (e1->bundle[ABOVE][SUBJ]) { + e0->bside[SUBJ] = !e0->bside[SUBJ]; + } + + /* Swap e0 and e1 bundles in the AET */ + prev_edge = e0->prev; + next_edge = e1->next; + if (next_edge) { + next_edge->prev = e0; + } + if (e0->bstate[ABOVE] == BUNDLE_HEAD) { + search = 1; + while (search) { + prev_edge = prev_edge->prev; + if (prev_edge) { + if (prev_edge->bstate[ABOVE] != BUNDLE_TAIL) { + search = 0; + } + } else { + search = 0; + } + } + } + if (!prev_edge) { + aet->prev = e1; + e1->next = aet; + aet = e0->next; + } else { + prev_edge->next->prev = e1; + e1->next = prev_edge->next; + prev_edge->next = e0->next; + } + e0->next->prev = prev_edge; + e1->next->prev = e1; + e0->next = next_edge; + } /* End of IT loop*/ + + // Prepare for next scanbeam + for (edge = aet; edge; edge = next_edge) { + next_edge = edge->next; + succ_edge = edge->succ; + if ((edge->top.y == yt) && succ_edge) { + /* Replace AET edge by its successor */ + succ_edge->outp[BELOW] = edge->outp[ABOVE]; + succ_edge->bstate[BELOW] = edge->bstate[ABOVE]; + succ_edge->bundle[BELOW][CLIP] = edge->bundle[ABOVE][CLIP]; + succ_edge->bundle[BELOW][SUBJ] = edge->bundle[ABOVE][SUBJ]; + prev_edge = edge->prev; + if (prev_edge) { + prev_edge->next = succ_edge; + } else { + aet = succ_edge; + } + if (next_edge) { + next_edge->prev = succ_edge; + } + succ_edge->prev = prev_edge; + succ_edge->next = next_edge; + } else { + /* Update this edge */ + edge->outp[BELOW] = edge->outp[ABOVE]; + edge->bstate[BELOW] = edge->bstate[ABOVE]; + edge->bundle[BELOW][CLIP] = edge->bundle[ABOVE][CLIP]; + edge->bundle[BELOW][SUBJ] = edge->bundle[ABOVE][SUBJ]; + edge->xb = edge->xt; + } + edge->outp[ABOVE] = NULL; + } + } + } /* === END OF SCANBEAM PROCESSING ================================== */ + // Generate result polygon from out_poly + result->contour = NULL; + result->hole = NULL; + result->num_contours = count_contours(out_poly); + if (result->num_contours > 0) { + gpc_malloc(result->hole, result->num_contours * sizeof(int), + const_cast("hole flag table creation")); + gpc_malloc(result->contour, + result->num_contours * sizeof(gpc_vertex_list), + const_cast("contour creation")); + + c = 0; + for (poly = out_poly; poly; poly = npoly) { + npoly = poly->next; + if (poly->active) { + result->hole[c] = poly->proxy->hole; + result->contour[c].num_vertices = poly->active; + gpc_malloc( + result->contour[c].vertex, + result->contour[c].num_vertices * sizeof(gpc_vertex), + const_cast("vertex creation")); + + v = result->contour[c].num_vertices - 1; + for (vtx = poly->proxy->v[LEFT]; vtx; vtx = nv) { + nv = vtx->next; + result->contour[c].vertex[v].x = vtx->x; + result->contour[c].vertex[v].y = vtx->y; + gpc_free(vtx); + v--; + } + c++; + } + gpc_free(poly); + } + } else { + for (poly = out_poly; poly; poly = npoly) { + npoly = poly->next; + gpc_free(poly); + } + } + + // Tidy up + reset_it(&it); + reset_lmt(&lmt); + gpc_free(c_heap); + gpc_free(s_heap); + gpc_free(sbt); +} // NOLINT + +void gpc_free_tristrip(gpc_tristrip *t) { + int s = 0; + for (s = 0; s < t->num_strips; s++) { + gpc_free(t->strip[s].vertex); + } + gpc_free(t->strip); + t->num_strips = 0; +} + +void gpc_polygon_to_tristrip(gpc_polygon *s, gpc_tristrip *t) { + gpc_polygon c; + c.num_contours = 0; + c.hole = NULL; + c.contour = NULL; + gpc_tristrip_clip(GPC_DIFF, s, &c, t); +} + +// gpc_tristrip_clip +void gpc_tristrip_clip(gpc_op op, gpc_polygon *subj, gpc_polygon *clip, + gpc_tristrip *result) { + sb_tree *sbtree = NULL; + it_node *it = NULL; + it_node *intersect = NULL; + edge_node *edge = NULL; + edge_node *prev_edge = NULL; + edge_node *next_edge = NULL; + edge_node *succ_edge = NULL; + edge_node *e0 = NULL; + edge_node *e1 = NULL; + edge_node *aet = NULL; + edge_node *c_heap = NULL; + edge_node *s_heap = NULL; + edge_node *cf = NULL; + lmt_node *lmt = NULL; + lmt_node *local_min = NULL; + polygon_node *tlist = NULL; + polygon_node *tn = NULL; + polygon_node *tnn = NULL; + polygon_node *p = NULL; + polygon_node *q = NULL; + vertex_node *lt = NULL; + vertex_node *ltn = NULL; + vertex_node *rt = NULL; + vertex_node *rtn = NULL; + h_state horiz[2]; + vertex_type cft = NUL; + int in[2]; + int exists[2]; + int parity[2] = {LEFT, LEFT}; + int s = 0; + int v = 0; + int contributing = 0; + int search = 0; + int scanbeam = 0; + int sbt_entries = 0; + int vclass = 0; + int bl = 0; + int br = 0; + int tl = 0; + int tr = 0; + double *sbt = NULL; + double xb = 0.0; + double px = 0.0; + double nx = 0.0; + double yb = 0.0; + double yt = 0.0; + double dy = 0.0; + double ix = 0.0; + double iy = 0.0; + + /* Test for trivial NULL result cases */ + if (((subj->num_contours == 0) && (clip->num_contours == 0)) || + ((subj->num_contours == 0) && ((op == GPC_INT) || (op == GPC_DIFF))) || + ((clip->num_contours == 0) && (op == GPC_INT))) { + result->num_strips = 0; + result->strip = NULL; + return; + } + + /* Identify potentialy contributing contours */ + if (((op == GPC_INT) || (op == GPC_DIFF)) && (subj->num_contours > 0) && + (clip->num_contours > 0)) { + minimax_test(subj, clip, op); + } + /* Build LMT */ + if (subj->num_contours > 0) { + s_heap = build_lmt(&lmt, &sbtree, &sbt_entries, subj, SUBJ, op); + } + if (clip->num_contours > 0) { + c_heap = build_lmt(&lmt, &sbtree, &sbt_entries, clip, CLIP, op); + } + /* Return a NULL result if no contours contribute */ + if (lmt == NULL) { + result->num_strips = 0; + result->strip = NULL; + reset_lmt(&lmt); + gpc_free(s_heap); + gpc_free(c_heap); + return; + } + + /* Build scanbeam table from scanbeam tree */ + gpc_malloc(sbt, sbt_entries * sizeof(double), + const_cast("sbt creation")); + build_sbt(&scanbeam, sbt, sbtree); + scanbeam = 0; + free_sbtree(&sbtree); + + /* Invert clip polygon for difference operation */ + if (op == GPC_DIFF) { + parity[CLIP] = RIGHT; + } + local_min = lmt; + + // Process each scanbeam + while (scanbeam < sbt_entries) { + /* Set yb and yt to the bottom and top of the scanbeam */ + yb = sbt[scanbeam++]; + if (scanbeam < sbt_entries) { + yt = sbt[scanbeam]; + dy = yt - yb; + } + + /* === SCANBEAM BOUNDARY PROCESSING ================================ */ + /* If LMT node corresponding to yb exists */ + if (local_min) { + if (local_min->y == yb) { + /* Add edges starting at this local minimum to the AET */ + for (edge = local_min->first_bound; edge; edge = edge->next_bound) { + add_edge_to_aet(&aet, edge, NULL); + } + local_min = local_min->next; + } + } + /* Set dummy previous x value */ + /* Create bundles within AET */ + px = -DBL_MAX; + e0 = aet; + e1 = aet; + + /* Set up bundle fields of first edge */ + aet->bundle[ABOVE][aet->type] = (aet->top.y != yb); + aet->bundle[ABOVE][!aet->type] = 0; + aet->bstate[ABOVE] = UNBUNDLED; + + for (next_edge = aet->next; next_edge; next_edge = next_edge->next) { + /* Set up bundle fields of next edge */ + next_edge->bundle[ABOVE][next_edge->type] = (next_edge->top.y != yb); + next_edge->bundle[ABOVE][!next_edge->type] = 0; + next_edge->bstate[ABOVE] = UNBUNDLED; + + /* Bundle edges above the scanbeam boundary if they coincide */ + if (next_edge->bundle[ABOVE][next_edge->type]) { + if (gpc_eq(e0->xb, next_edge->xb) && gpc_eq(e0->dx, next_edge->dx) && + (e0->top.y != yb)) { + next_edge->bundle[ABOVE][next_edge->type] ^= + e0->bundle[ABOVE][next_edge->type]; + next_edge->bundle[ABOVE][!next_edge->type] = + e0->bundle[ABOVE][!next_edge->type]; + next_edge->bstate[ABOVE] = BUNDLE_HEAD; + e0->bundle[ABOVE][CLIP] = 0; + e0->bundle[ABOVE][SUBJ] = 0; + e0->bstate[ABOVE] = BUNDLE_TAIL; + } + e0 = next_edge; + } + } + horiz[CLIP] = NH; + horiz[SUBJ] = NH; + + /* Process each edge at this scanbeam boundary */ + for (edge = aet; edge; edge = edge->next) { + exists[CLIP] = + edge->bundle[ABOVE][CLIP] + (edge->bundle[BELOW][CLIP] << 1); + exists[SUBJ] = + edge->bundle[ABOVE][SUBJ] + (edge->bundle[BELOW][SUBJ] << 1); + + if (exists[CLIP] || exists[SUBJ]) { + /* Set bundle side */ + edge->bside[CLIP] = parity[CLIP]; + edge->bside[SUBJ] = parity[SUBJ]; + + /* Determine contributing status and quadrant occupancies */ + switch (op) { + case GPC_DIFF: + case GPC_INT: + contributing = (exists[CLIP] && (parity[SUBJ] || horiz[SUBJ])) || + (exists[SUBJ] && (parity[CLIP] || horiz[CLIP])) || + (exists[CLIP] && exists[SUBJ] && + (parity[CLIP] == parity[SUBJ])); + br = (parity[CLIP]) && (parity[SUBJ]); + bl = (parity[CLIP] ^ edge->bundle[ABOVE][CLIP]) && + (parity[SUBJ] ^ edge->bundle[ABOVE][SUBJ]); + tr = (parity[CLIP] ^ (horiz[CLIP] != NH)) && + (parity[SUBJ] ^ (horiz[SUBJ] != NH)); + tl = (parity[CLIP] ^ (horiz[CLIP] != NH) ^ + edge->bundle[BELOW][CLIP]) && + (parity[SUBJ] ^ (horiz[SUBJ] != NH) ^ + edge->bundle[BELOW][SUBJ]); + break; + case GPC_XOR: + contributing = exists[CLIP] || exists[SUBJ]; + br = (parity[CLIP]) ^ (parity[SUBJ]); + bl = (parity[CLIP] ^ edge->bundle[ABOVE][CLIP]) ^ + (parity[SUBJ] ^ edge->bundle[ABOVE][SUBJ]); + tr = (parity[CLIP] ^ (horiz[CLIP] != NH)) ^ + (parity[SUBJ] ^ (horiz[SUBJ] != NH)); + tl = (parity[CLIP] ^ (horiz[CLIP] != NH) ^ + edge->bundle[BELOW][CLIP]) ^ + (parity[SUBJ] ^ (horiz[SUBJ] != NH) ^ + edge->bundle[BELOW][SUBJ]); + break; + case GPC_UNION: + contributing = (exists[CLIP] && (!parity[SUBJ] || horiz[SUBJ])) || + (exists[SUBJ] && (!parity[CLIP] || horiz[CLIP])) || + (exists[CLIP] && exists[SUBJ] && + (parity[CLIP] == parity[SUBJ])); + br = (parity[CLIP]) || (parity[SUBJ]); + bl = (parity[CLIP] ^ edge->bundle[ABOVE][CLIP]) || + (parity[SUBJ] ^ edge->bundle[ABOVE][SUBJ]); + tr = (parity[CLIP] ^ (horiz[CLIP] != NH)) || + (parity[SUBJ] ^ (horiz[SUBJ] != NH)); + tl = (parity[CLIP] ^ (horiz[CLIP] != NH) ^ + edge->bundle[BELOW][CLIP]) || + (parity[SUBJ] ^ (horiz[SUBJ] != NH) ^ + edge->bundle[BELOW][SUBJ]); + break; + } + + // Update parity + parity[CLIP] ^= edge->bundle[ABOVE][CLIP]; + parity[SUBJ] ^= edge->bundle[ABOVE][SUBJ]; + + /* Update horizontal state */ + if (exists[CLIP]) { + horiz[CLIP] = next_h_state[horiz[CLIP]] + [((exists[CLIP] - 1) << 1) + parity[CLIP]]; + } + if (exists[SUBJ]) { + horiz[SUBJ] = next_h_state[horiz[SUBJ]] + [((exists[SUBJ] - 1) << 1) + parity[SUBJ]]; + } + vclass = tr + (tl << 1) + (br << 2) + (bl << 3); + + if (contributing) { + xb = edge->xb; + switch (vclass) { + case EMN: + new_tristrip(&tlist, edge, xb, yb); + cf = edge; + break; + case ERI: + edge->outp[ABOVE] = cf->outp[ABOVE]; + if (xb != cf->xb) { + gpc_vertex_create(edge, ABOVE, RIGHT, xb, yb); + } + cf = NULL; + break; + case ELI: + gpc_vertex_create(edge, BELOW, LEFT, xb, yb); + edge->outp[ABOVE] = NULL; + cf = edge; + break; + case EMX: + if (xb != cf->xb) { + gpc_vertex_create(edge, BELOW, RIGHT, xb, yb); + } + edge->outp[ABOVE] = NULL; + cf = NULL; + break; + case IMN: + if (cft == LED) { + if (cf->bot.y != yb) { + gpc_vertex_create(cf, BELOW, LEFT, cf->xb, yb); + } + new_tristrip(&tlist, cf, cf->xb, yb); + } + edge->outp[ABOVE] = cf->outp[ABOVE]; + gpc_vertex_create(edge, ABOVE, RIGHT, xb, yb); + break; + case ILI: + new_tristrip(&tlist, edge, xb, yb); + cf = edge; + cft = ILI; + break; + case IRI: + if (cft == LED) { + if (cf->bot.y != yb) { + gpc_vertex_create(cf, BELOW, LEFT, cf->xb, yb); + } + new_tristrip(&tlist, cf, cf->xb, yb); + } + gpc_vertex_create(edge, BELOW, RIGHT, xb, yb); + edge->outp[ABOVE] = NULL; + break; + case IMX: + gpc_vertex_create(edge, BELOW, LEFT, xb, yb); + edge->outp[ABOVE] = NULL; + cft = IMX; + break; + case IMM: + gpc_vertex_create(edge, BELOW, LEFT, xb, yb); + edge->outp[ABOVE] = cf->outp[ABOVE]; + if (xb != cf->xb) { + gpc_vertex_create(cf, ABOVE, RIGHT, xb, yb); + } + cf = edge; + break; + case EMM: + gpc_vertex_create(edge, BELOW, RIGHT, xb, yb); + edge->outp[ABOVE] = NULL; + new_tristrip(&tlist, edge, xb, yb); + cf = edge; + break; + case LED: + if (edge->bot.y == yb) { + gpc_vertex_create(edge, BELOW, LEFT, xb, yb); + } + edge->outp[ABOVE] = edge->outp[BELOW]; + cf = edge; + cft = LED; + break; + case RED: + edge->outp[ABOVE] = cf->outp[ABOVE]; + if (cft == LED) { + if (cf->bot.y == yb) { + gpc_vertex_create(edge, BELOW, RIGHT, xb, yb); + } else { + if (edge->bot.y == yb) { + gpc_vertex_create(cf, BELOW, LEFT, cf->xb, yb); + gpc_vertex_create(edge, BELOW, RIGHT, xb, yb); + } + } + } else { + gpc_vertex_create(edge, BELOW, RIGHT, xb, yb); + gpc_vertex_create(edge, ABOVE, RIGHT, xb, yb); + } + cf = NULL; + break; + default: + break; + } /* End of switch */ + } /* End of contributing conditional */ + } /* End of edge exists conditional */ + } // End of AET loop + + /* Delete terminating edges from the AET, otherwise compute xt */ + for (edge = aet; edge; edge = edge->next) { + if (edge->top.y == yb) { + prev_edge = edge->prev; + next_edge = edge->next; + if (prev_edge) { + prev_edge->next = next_edge; + } else { + aet = next_edge; + } + if (next_edge) { + next_edge->prev = prev_edge; + } + + /* Copy bundle head state to the adjacent tail edge if required */ + if ((edge->bstate[BELOW] == BUNDLE_HEAD) && prev_edge) { + if (prev_edge->bstate[BELOW] == BUNDLE_TAIL) { + prev_edge->outp[BELOW] = edge->outp[BELOW]; + prev_edge->bstate[BELOW] = UNBUNDLED; + if (prev_edge->prev) { + if (prev_edge->prev->bstate[BELOW] == BUNDLE_TAIL) { + prev_edge->bstate[BELOW] = BUNDLE_HEAD; + } + } + } + } + } else { + if (edge->top.y == yt) { + edge->xt = edge->top.x; + } else { + edge->xt = edge->bot.x + edge->dx * (yt - edge->bot.y); + } + } + } + + if (scanbeam < sbt_entries) { + /* === SCANBEAM INTERIOR PROCESSING ============================== */ + build_intersection_table(&it, aet, dy); + /* Process each node in the intersection table */ + for (intersect = it; intersect; intersect = intersect->next) { + e0 = intersect->ie[0]; + e1 = intersect->ie[1]; + + /* Only generate output for contributing intersections */ + if ((e0->bundle[ABOVE][CLIP] || e0->bundle[ABOVE][SUBJ]) && + (e1->bundle[ABOVE][CLIP] || e1->bundle[ABOVE][SUBJ])) { + p = e0->outp[ABOVE]; + q = e1->outp[ABOVE]; + ix = intersect->point.x; + iy = intersect->point.y + yb; + + in[CLIP] = (e0->bundle[ABOVE][CLIP] && !e0->bside[CLIP]) || + (e1->bundle[ABOVE][CLIP] && e1->bside[CLIP]) || + (!e0->bundle[ABOVE][CLIP] && !e1->bundle[ABOVE][CLIP] && + e0->bside[CLIP] && e1->bside[CLIP]); + in[SUBJ] = (e0->bundle[ABOVE][SUBJ] && !e0->bside[SUBJ]) || + (e1->bundle[ABOVE][SUBJ] && e1->bside[SUBJ]) || + (!e0->bundle[ABOVE][SUBJ] && !e1->bundle[ABOVE][SUBJ] && + e0->bside[SUBJ] && e1->bside[SUBJ]); + + switch (op) { // Determine quadrant occupancies + case GPC_DIFF: + case GPC_INT: + tr = (in[CLIP]) && (in[SUBJ]); + tl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP]) && + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ]); + br = (in[CLIP] ^ e0->bundle[ABOVE][CLIP]) && + (in[SUBJ] ^ e0->bundle[ABOVE][SUBJ]); + bl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP] ^ + e0->bundle[ABOVE][CLIP]) && + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ] ^ + e0->bundle[ABOVE][SUBJ]); + break; + case GPC_XOR: + tr = (in[CLIP]) ^ (in[SUBJ]); + tl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP]) ^ + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ]); + br = (in[CLIP] ^ e0->bundle[ABOVE][CLIP]) ^ + (in[SUBJ] ^ e0->bundle[ABOVE][SUBJ]); + bl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP] ^ + e0->bundle[ABOVE][CLIP]) ^ + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ] ^ + e0->bundle[ABOVE][SUBJ]); + break; + case GPC_UNION: + tr = (in[CLIP]) || (in[SUBJ]); + tl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP]) || + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ]); + br = (in[CLIP] ^ e0->bundle[ABOVE][CLIP]) || + (in[SUBJ] ^ e0->bundle[ABOVE][SUBJ]); + bl = (in[CLIP] ^ e1->bundle[ABOVE][CLIP] ^ + e0->bundle[ABOVE][CLIP]) || + (in[SUBJ] ^ e1->bundle[ABOVE][SUBJ] ^ + e0->bundle[ABOVE][SUBJ]); + break; + } + + vclass = tr + (tl << 1) + (br << 2) + (bl << 3); + switch (vclass) { + case EMN: + new_tristrip(&tlist, e1, ix, iy); + e0->outp[ABOVE] = e1->outp[ABOVE]; + break; + case ERI: + if (p) { + gpc_p_edge(prev_edge, e0, ABOVE); + gpc_vertex_create(prev_edge, ABOVE, LEFT, px, iy); + gpc_vertex_create(e0, ABOVE, RIGHT, ix, iy); + e1->outp[ABOVE] = e0->outp[ABOVE]; + e0->outp[ABOVE] = NULL; + } + break; + case ELI: + if (q) { + gpc_n_edge(next_edge, e1, ABOVE); + gpc_vertex_create(e1, ABOVE, LEFT, ix, iy); + gpc_vertex_create(next_edge, ABOVE, RIGHT, nx, iy); + e0->outp[ABOVE] = e1->outp[ABOVE]; + e1->outp[ABOVE] = NULL; + } + break; + case EMX: + if (p && q) { + gpc_vertex_create(e0, ABOVE, LEFT, ix, iy); + e0->outp[ABOVE] = NULL; + e1->outp[ABOVE] = NULL; + } + break; + case IMN: + gpc_p_edge(prev_edge, e0, ABOVE); + gpc_vertex_create(prev_edge, ABOVE, LEFT, px, iy); + gpc_n_edge(next_edge, e1, ABOVE); + gpc_vertex_create(next_edge, ABOVE, RIGHT, nx, iy); + new_tristrip(&tlist, prev_edge, px, iy); + e1->outp[ABOVE] = prev_edge->outp[ABOVE]; + gpc_vertex_create(e1, ABOVE, RIGHT, ix, iy); + new_tristrip(&tlist, e0, ix, iy); + next_edge->outp[ABOVE] = e0->outp[ABOVE]; + gpc_vertex_create(next_edge, ABOVE, RIGHT, nx, iy); + break; + case ILI: + if (p) { + gpc_vertex_create(e0, ABOVE, LEFT, ix, iy); + gpc_n_edge(next_edge, e1, ABOVE); + gpc_vertex_create(next_edge, ABOVE, RIGHT, nx, iy); + e1->outp[ABOVE] = e0->outp[ABOVE]; + e0->outp[ABOVE] = NULL; + } + break; + case IRI: + if (q) { + gpc_vertex_create(e1, ABOVE, RIGHT, ix, iy); + gpc_p_edge(prev_edge, e0, ABOVE); + gpc_vertex_create(prev_edge, ABOVE, LEFT, px, iy); + e0->outp[ABOVE] = e1->outp[ABOVE]; + e1->outp[ABOVE] = NULL; + } + break; + case IMX: + if (p && q) { + gpc_vertex_create(e0, ABOVE, RIGHT, ix, iy); + gpc_vertex_create(e1, ABOVE, LEFT, ix, iy); + e0->outp[ABOVE] = NULL; + e1->outp[ABOVE] = NULL; + gpc_p_edge(prev_edge, e0, ABOVE); + gpc_vertex_create(prev_edge, ABOVE, LEFT, px, iy); + new_tristrip(&tlist, prev_edge, px, iy); + gpc_n_edge(next_edge, e1, ABOVE); + gpc_vertex_create(next_edge, ABOVE, RIGHT, nx, iy); + next_edge->outp[ABOVE] = prev_edge->outp[ABOVE]; + gpc_vertex_create(next_edge, ABOVE, RIGHT, nx, iy); + } + break; + case IMM: + if (p && q) { + gpc_vertex_create(e0, ABOVE, RIGHT, ix, iy); + gpc_vertex_create(e1, ABOVE, LEFT, ix, iy); + gpc_p_edge(prev_edge, e0, ABOVE); + gpc_vertex_create(prev_edge, ABOVE, LEFT, px, iy); + new_tristrip(&tlist, prev_edge, px, iy); + gpc_n_edge(next_edge, e1, ABOVE); + gpc_vertex_create(next_edge, ABOVE, RIGHT, nx, iy); + e1->outp[ABOVE] = prev_edge->outp[ABOVE]; + gpc_vertex_create(e1, ABOVE, RIGHT, ix, iy); + new_tristrip(&tlist, e0, ix, iy); + next_edge->outp[ABOVE] = e0->outp[ABOVE]; + gpc_vertex_create(next_edge, ABOVE, RIGHT, nx, iy); + } + break; + case EMM: + if (p && q) { + gpc_vertex_create(e0, ABOVE, LEFT, ix, iy); + new_tristrip(&tlist, e1, ix, iy); + e0->outp[ABOVE] = e1->outp[ABOVE]; + } + break; + default: + break; + } /* End of switch */ + } /* End of contributing intersection conditional */ + + // Swap bundle sides in response to edge crossing + if (e0->bundle[ABOVE][CLIP]) { + e1->bside[CLIP] = !e1->bside[CLIP]; + } + if (e1->bundle[ABOVE][CLIP]) { + e0->bside[CLIP] = !e0->bside[CLIP]; + } + if (e0->bundle[ABOVE][SUBJ]) { + e1->bside[SUBJ] = !e1->bside[SUBJ]; + } + if (e1->bundle[ABOVE][SUBJ]) { + e0->bside[SUBJ] = !e0->bside[SUBJ]; + } + + /* Swap e0 and e1 bundles in the AET */ + prev_edge = e0->prev; + next_edge = e1->next; + if (e1->next) { + e1->next->prev = e0; + } + + if (e0->bstate[ABOVE] == BUNDLE_HEAD) { + search = 1; + while (search) { + prev_edge = prev_edge->prev; + if (prev_edge) { + if (prev_edge->bundle[ABOVE][CLIP] || + prev_edge->bundle[ABOVE][SUBJ] || + (prev_edge->bstate[ABOVE] == BUNDLE_HEAD)) { + search = 0; + } + } else { + search = 0; + } + } + } + if (!prev_edge) { + e1->next = aet; + aet = e0->next; + } else { + e1->next = prev_edge->next; + prev_edge->next = e0->next; + } + e0->next->prev = prev_edge; + e1->next->prev = e1; + e0->next = next_edge; + } /* End of IT loop*/ + + /* Prepare for next scanbeam */ + for (edge = aet; edge; edge = next_edge) { + next_edge = edge->next; + succ_edge = edge->succ; + + if ((edge->top.y == yt) && succ_edge) { + /* Replace AET edge by its successor */ + succ_edge->outp[BELOW] = edge->outp[ABOVE]; + succ_edge->bstate[BELOW] = edge->bstate[ABOVE]; + succ_edge->bundle[BELOW][CLIP] = edge->bundle[ABOVE][CLIP]; + succ_edge->bundle[BELOW][SUBJ] = edge->bundle[ABOVE][SUBJ]; + prev_edge = edge->prev; + if (prev_edge) { + prev_edge->next = succ_edge; + } else { + aet = succ_edge; + } + if (next_edge) { + next_edge->prev = succ_edge; + } + succ_edge->prev = prev_edge; + succ_edge->next = next_edge; + } else { + /* Update this edge */ + edge->outp[BELOW] = edge->outp[ABOVE]; + edge->bstate[BELOW] = edge->bstate[ABOVE]; + edge->bundle[BELOW][CLIP] = edge->bundle[ABOVE][CLIP]; + edge->bundle[BELOW][SUBJ] = edge->bundle[ABOVE][SUBJ]; + edge->xb = edge->xt; + } + edge->outp[ABOVE] = NULL; + } + } + } /* === END OF SCANBEAM PROCESSING ================================== */ + + // Generate result tristrip from tlist + result->strip = NULL; + result->num_strips = count_tristrips(tlist); + if (result->num_strips > 0) { + gpc_malloc(result->strip, + result->num_strips * sizeof(gpc_vertex_list), + const_cast("tristrip list creation")); + + s = 0; + for (tn = tlist; tn; tn = tnn) { + tnn = tn->next; + if (tn->active > 2) { + /* Valid tristrip: copy the vertices and free the heap */ + result->strip[s].num_vertices = tn->active; + gpc_malloc(result->strip[s].vertex, + tn->active * sizeof(gpc_vertex), + const_cast("tristrip creation")); + v = 0; + if (0) { + lt = tn->v[RIGHT]; + rt = tn->v[LEFT]; + } else { + lt = tn->v[LEFT]; + rt = tn->v[RIGHT]; + } + while (lt || rt) { + if (lt) { + ltn = lt->next; + result->strip[s].vertex[v].x = lt->x; + result->strip[s].vertex[v].y = lt->y; + v++; + gpc_free(lt); + lt = ltn; + } + if (rt) { + rtn = rt->next; + result->strip[s].vertex[v].x = rt->x; + result->strip[s].vertex[v].y = rt->y; + v++; + gpc_free(rt); + rt = rtn; + } + } + s++; + } else { + /* Invalid tristrip: just free the heap */ + for (lt = tn->v[LEFT]; lt; lt = ltn) { + ltn = lt->next; + gpc_free(lt); + } + for (rt = tn->v[RIGHT]; rt; rt = rtn) { + rtn = rt->next; + gpc_free(rt); + } + } + gpc_free(tn); + } + } + // Tidy up + reset_it(&it); + reset_lmt(&lmt); + gpc_free(c_heap); + gpc_free(s_heap); + gpc_free(sbt); +} // NOLINT + +} // namespace gpc + +#endif diff --git a/src/operators/math/gpc.h b/src/operators/math/gpc.h new file mode 100644 index 0000000000000000000000000000000000000000..2cae7fe18458ee6f42f3cc6f374982214f041f84 --- /dev/null +++ b/src/operators/math/gpc.h @@ -0,0 +1,222 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef MULTICLASSNMS_OP +#pragma once + +#include +#include +#include +#include + +namespace gpc { + +typedef enum { // Set operation type + GPC_DIFF, // Difference + GPC_INT, // Intersection + GPC_XOR, // Exclusive or + GPC_UNION // Union +} gpc_op; + +typedef struct { // Polygon vertex structure + double x; // Vertex x component + double y; // vertex y component +} gpc_vertex; + +typedef struct { // Vertex list structure + int num_vertices; // Number of vertices in list + gpc_vertex *vertex; // Vertex array pointer +} gpc_vertex_list; + +typedef struct { // Polygon set structure + int num_contours; // Number of contours in polygon + int *hole; // Hole external contour flags + gpc_vertex_list *contour; // Contour array pointer +} gpc_polygon; + +typedef struct { // Tristrip set structure + int num_strips; // Number of tristrips + gpc_vertex_list *strip; // Tristrip array pointer +} gpc_tristrip; + +typedef enum { LEFT, RIGHT } gpc_left_right; + +typedef enum { ABOVE, BELOW } gpc_above_below; + +typedef enum { CLIP, SUBJ } gpc_clip_subj; + +typedef enum { /* Edge intersection classes */ + NUL, /* Empty non-intersection */ + EMX, /* External maximum */ + ELI, /* External left intermediate */ + TED, /* Top edge */ + ERI, /* External right intermediate */ + RED, /* Right edge */ + IMM, /* Internal maximum and minimum */ + IMN, /* Internal minimum */ + EMN, /* External minimum */ + EMM, /* External maximum and minimum */ + LED, /* Left edge */ + ILI, /* Internal left intermediate */ + BED, /* Bottom edge */ + IRI, /* Internal right intermediate */ + IMX, /* Internal maximum */ + FUL /* Full non-intersection */ +} vertex_type; + +typedef enum { /* Horizontal edge states */ + NH, /* No horizontal edge */ + BH, /* Bottom horizontal edge */ + TH /* Top horizontal edge */ +} h_state; + +typedef enum { /* Edge bundle state */ + UNBUNDLED, /* Isolated edge not within a bundle */ + BUNDLE_HEAD, /* Bundle head node */ + BUNDLE_TAIL /* Passive bundle tail node */ +} bundle_state; + +typedef struct v_shape { /* Internal vertex list datatype */ + double x; /* X coordinate component */ + double y; /* Y coordinate component */ + struct v_shape *next; /* Pointer to next vertex in list */ +} vertex_node; + +typedef struct p_shape { /* Internal contour / tristrip type */ + int active; /* Active flag / vertex count */ + int hole; /* Hole / external contour flag */ + vertex_node *v[2]; /* Left and right vertex list ptrs */ + struct p_shape *next; /* Pointer to next polygon contour */ + struct p_shape *proxy; /* Pointer to actual structure used */ +} polygon_node; + +typedef struct edge_shape { + gpc_vertex vertex; /* Piggy-backed contour vertex data */ + gpc_vertex bot; /* Edge lower (x, y) coordinate */ + gpc_vertex top; /* Edge upper (x, y) coordinate */ + double xb; /* Scanbeam bottom x coordinate */ + double xt; /* Scanbeam top x coordinate */ + double dx; /* Change in x for a unit y increase */ + int type; /* Clip / subject edge flag */ + int bundle[2][2]; /* Bundle edge flags */ + int bside[2]; /* Bundle left / right indicators */ + bundle_state bstate[2]; /* Edge bundle state */ + polygon_node *outp[2]; /* Output polygon / tristrip pointer */ + struct edge_shape *prev; /* Previous edge in the AET */ + struct edge_shape *next; /* Next edge in the AET */ + struct edge_shape *pred; /* Edge connected at the lower end */ + struct edge_shape *succ; /* Edge connected at the upper end */ + struct edge_shape *next_bound; /* Pointer to next bound in LMT */ +} edge_node; + +inline bool gpc_eq(float a, float b) { return (fabs(a - b) <= 1e-6); } + +inline bool gpc_prev_index(float a, float b) { return (fabs(a - b) <= 1e-6); } + +inline int gpc_prev_index(int i, int n) { return ((i - 1 + n) % n); } + +inline int gpc_next_index(int i, int n) { return ((i + 1) % n); } + +inline int gpc_optimal(gpc_vertex *v, int i, int n) { + return (v[(i + 1) % n].y != v[i].y || v[(i - 1 + n) % n].y != v[i].y); +} + +inline int gpc_fwd_min(edge_node *v, int i, int n) { + return (v[(i + 1) % n].vertex.y > v[i].vertex.y && + v[(i - 1 + n) % n].vertex.y >= v[i].vertex.y); +} + +inline int gpc_not_fmax(edge_node *v, int i, int n) { + return (v[(i + 1) % n].vertex.y > v[i].vertex.y); +} + +inline int gpc_rev_min(edge_node *v, int i, int n) { + return (v[(i + 1) % n].vertex.y >= v[i].vertex.y && + v[(i - 1 + n) % n].vertex.y > v[i].vertex.y); +} + +inline int gpc_not_rmax(edge_node *v, int i, int n) { + return (v[(i - 1 + n) % n].vertex.y > v[i].vertex.y); +} + +// inline void gpc_p_edge(edge_node *d, edge_node *e, int p, double i, double j) +// { +inline void gpc_p_edge(edge_node *d, edge_node *e, int p) { + d = e; + do { + d = d->prev; + } while (!d->outp[p]); + // i = d->bot.x + d->dx * (j - d->bot.y); +} + +// inline void gpc_n_edge(edge_node *d, edge_node *e, int p, double i, double j) +// { +inline void gpc_n_edge(edge_node *d, edge_node *e, int p) { + d = e; + do { + d = d->next; + } while (!d->outp[p]); + // i = d->bot.x + d->dx * (j - d->bot.y); +} + +template +void gpc_malloc(T *&p, int b, char *s) { // NOLINT + if (b > 0) { + p = reinterpret_cast(malloc(b)); + + if (!p) { + fprintf(stderr, "gpc malloc failure: %s\n", s); + exit(0); + } + } else { + p = NULL; + } +} + +template +void gpc_free(T *&p) { // NOLINT + if (p) { + free(p); + p = NULL; + } +} + +/* +=========================================================================== + Public Function Prototypes +=========================================================================== +*/ + +void add_vertex(vertex_node **t, double x, double y); + +void gpc_vertex_create(edge_node *e, int p, int s, double x, double y); + +void gpc_add_contour(gpc_polygon *polygon, gpc_vertex_list *contour, int hole); + +void gpc_polygon_clip(gpc_op set_operation, gpc_polygon *subject_polygon, + gpc_polygon *clip_polygon, gpc_polygon *result_polygon); + +void gpc_tristrip_clip(gpc_op set_operation, gpc_polygon *subject_polygon, + gpc_polygon *clip_polygon, + gpc_tristrip *result_tristrip); + +void gpc_polygon_to_tristrip(gpc_polygon *polygon, gpc_tristrip *tristrip); + +void gpc_free_polygon(gpc_polygon *polygon); + +void gpc_free_tristrip(gpc_tristrip *tristrip); + +} // namespace gpc + +#endif diff --git a/src/operators/math/poly_util.cpp b/src/operators/math/poly_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1cc1e2a40374204c8644267e8ab84af3cba5c65a --- /dev/null +++ b/src/operators/math/poly_util.cpp @@ -0,0 +1,120 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef MULTICLASSNMS_OP + +#include "operators/math/poly_util.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +template +void Array2PointVec(const T* box, const size_t box_size, + std::vector>* vec) { + size_t pts_num = box_size / 2; + vec->resize(pts_num); + for (size_t i = 0; i < pts_num; i++) { + vec->at(i).x = box[2 * i]; + vec->at(i).y = box[2 * i + 1]; + } +} + +template +void Array2Poly(const T* box, const size_t box_size, gpc::gpc_polygon* poly) { + size_t pts_num = box_size / 2; + poly->num_contours = 1; + poly->hole = reinterpret_cast(malloc(sizeof(int))); + poly->hole[0] = 0; + poly->contour = (gpc::gpc_vertex_list*)malloc(sizeof(gpc::gpc_vertex_list)); + poly->contour->num_vertices = pts_num; + poly->contour->vertex = + (gpc::gpc_vertex*)malloc(sizeof(gpc::gpc_vertex) * pts_num); + for (size_t i = 0; i < pts_num; ++i) { + poly->contour->vertex[i].x = box[2 * i]; + poly->contour->vertex[i].y = box[2 * i + 1]; + } +} + +template void Array2Poly(const float* box, const size_t box_size, + gpc::gpc_polygon* poly); + +template +void Poly2PointVec(const gpc::gpc_vertex_list& contour, + std::vector>* vec) { + int pts_num = contour.num_vertices; + vec->resize(pts_num); + for (size_t i = 0; i < pts_num; i++) { + vec->at(i).x = contour.vertex[i].x; + vec->at(i).y = contour.vertex[i].y; + } +} + +template +T GetContourArea(const std::vector>& vec) { + int pts_num = vec.size(); + if (pts_num < 3) return T(0.); + T area = T(0.); + for (size_t i = 0; i < pts_num; ++i) { + area += vec[i].x * vec[(i + 1) % pts_num].y - + vec[i].y * vec[(i + 1) % pts_num].x; + } + return fabs(area / 2.0); +} + +template +T PolyArea(const T* box, const size_t box_size, const bool normalized) { + // If coordinate values are is invalid + // if area size <= 0, return 0. + std::vector> vec; + Array2PointVec(box, box_size, &vec); + return GetContourArea(vec); +} + +template float PolyArea(const float* box, const size_t box_size, + const bool normalized); + +template +T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size, + const bool normalized) { + gpc::gpc_polygon poly1; + gpc::gpc_polygon poly2; + Array2Poly(box1, box_size, &poly1); + Array2Poly(box2, box_size, &poly2); + gpc::gpc_polygon respoly; + gpc::gpc_op op = gpc::GPC_INT; + gpc::gpc_polygon_clip(op, &poly2, &poly1, &respoly); + + T inter_area = T(0.); + int contour_num = respoly.num_contours; + for (int i = 0; i < contour_num; ++i) { + std::vector> resvec; + Poly2PointVec(respoly.contour[i], &resvec); + inter_area += GetContourArea(resvec); + } + + gpc::gpc_free_polygon(&poly1); + gpc::gpc_free_polygon(&poly2); + gpc::gpc_free_polygon(&respoly); + return inter_area; +} + +template float PolyOverlapArea(const float* box1, const float* box2, + const size_t box_size, const bool normalized); + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/poly_util.h b/src/operators/math/poly_util.h new file mode 100644 index 0000000000000000000000000000000000000000..96951a0ab1ff9ab25553b7290cfbb4a21c54cfc8 --- /dev/null +++ b/src/operators/math/poly_util.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef MULTICLASSNMS_OP +#pragma once + +#include +#include "operators/math/gpc.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +template +class Point_ { + public: + // default constructor + Point_() {} + Point_(T _x, T _y) {} + Point_(const Point_& pt) {} + + Point_& operator=(const Point_& pt); + // conversion to another data type + // template operator Point_<_T>() const; + // conversion to the old-style C structures + // operator Vec() const; + + // checks whether the point is inside the specified rectangle + // bool inside(const Rect_& r) const; + T x; //!< x coordinate of the point + T y; //!< y coordinate of the point +}; + +template +void Array2PointVec(const T* box, const size_t box_size, + std::vector>* vec); + +template +void Array2Poly(const T* box, const size_t box_size, gpc::gpc_polygon* poly); + +template +void Poly2PointVec(const gpc::gpc_vertex_list& contour, + std::vector>* vec); + +template +T GetContourArea(const std::vector>& vec); + +template +T PolyArea(const T* box, const size_t box_size, const bool normalized); + +template +T PolyOverlapArea(const T* box1, const T* box2, const size_t box_size, + const bool normalized); + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/selected_rows_functor.h b/src/operators/math/selected_rows_functor.h index 8cf1f5ca395d111ecca90f802773703ecb3286c9..f8b5521e4d19fd3199e7b05a902c98b731c9fbd0 100644 --- a/src/operators/math/selected_rows_functor.h +++ b/src/operators/math/selected_rows_functor.h @@ -47,7 +47,7 @@ struct SelectedRowsAddTo { const int64_t input2_offset, framework::SelectedRows* input2) { auto in1_height = input1.height(); - PADDLE_MOBILE_ENFORCE(in1_height == input2->height()); + PADDLE_MOBILE_ENFORCE(in1_height == input2->height(), "height error"); auto& in1_rows = input1.rows(); auto& in2_rows = *(input2->mutable_rows()); @@ -77,13 +77,14 @@ struct SelectedRowsAddToTensor { framework::Tensor* input2) { auto in1_height = input1.height(); auto in2_dims = input2->dims(); - PADDLE_MOBILE_ENFORCE(in1_height == in2_dims[0]); + PADDLE_MOBILE_ENFORCE(in1_height == in2_dims[0], "height != dims[0]"); auto& in1_value = input1.value(); auto& in1_rows = input1.rows(); int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); - PADDLE_MOBILE_ENFORCE(in1_row_numel == input2->numel() / in1_height); + PADDLE_MOBILE_ENFORCE(in1_row_numel == input2->numel() / in1_height, + "row_numel error"); auto* in1_data = in1_value.data(); auto* input2_data = input2->data(); diff --git a/src/operators/multiclass_nms_op.cpp b/src/operators/multiclass_nms_op.cpp index 97f4f1a1c650e2810b99a2938962ee7f8371dd2f..d29b84e56521ef98d8b3fbf000e5f6fba809fea3 100644 --- a/src/operators/multiclass_nms_op.cpp +++ b/src/operators/multiclass_nms_op.cpp @@ -25,8 +25,8 @@ void MultiClassNMSOp::InferShape() const { if (input_scores_dims.size() != 3) { LOG(kLOG_ERROR) << "Input Scores size must be 3"; } - if (input_bboxes_dims[2] != 4) { - LOG(kLOG_ERROR) << "Input BBoxes 2nd dimension must be 4"; + if (input_bboxes_dims[2] % 4 != 0 || input_bboxes_dims[2] < 4) { + LOG(kLOG_ERROR) << "Input BBoxes 2nd dimension must be multiples of 4"; } if (input_bboxes_dims[1] != input_scores_dims[2]) { LOG(kLOG_ERROR) << "Predict bboxes must be equal"; diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 27ab4629f011ba25390961b2679fd8f86d213fc3..70562da8f8961daed9c0057f3ebc8e1a1a6e340e 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -471,15 +471,6 @@ class ElementwiseMulParam : OpParam { GType *input_y_; GType *out_; int axis_; -#ifdef PADDLE_MOBILE_FPGA - - private: - fpga::EWMulArgs fpga_EW_mul_args; - - public: - const fpga::EWMulArgs &FpgaArgs() const { return fpga_EW_mul_args; } - void SetFpgaArgs(const fpga::EWMulArgs &args) { fpga_EW_mul_args = args; } -#endif }; #endif @@ -488,6 +479,38 @@ template using ElementwiseAddReluParam = ElementwiseAddParam; #endif +#ifdef ELEMENTWISESUB_OP +template +class ElementwiseSubParam : OpParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + ElementwiseSubParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs, + const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + input_y_ = InputYFrom(inputs, scope); + out_ = OutFrom(outputs, scope); + axis_ = GetAttr("axis", attrs); + } + + const GType *InputX() const { return input_x_; } + + const GType *InputY() const { return input_y_; } + + GType *Out() const { return out_; } + + const int &Axis() const { return axis_; } + + private: + GType *input_x_; + GType *input_y_; + GType *out_; + int axis_; +}; +#endif + #ifdef MUL_OP template class MulParam : OpParam { @@ -596,15 +619,6 @@ class SumParam : public OpParam { Variable *out_var_; vector inputs_; GType *out_; -#ifdef PADDLE_MOBILE_FPGA - - private: - fpga::SumArgs fpga_sum_args; - - public: - const fpga::SumArgs &FpgaArgs() const { return fpga_sum_args; } - void SetFpgaArgs(const fpga::SumArgs &args) { fpga_sum_args = args; } -#endif }; #endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d258d20dcc037abc2754316a1d337288d55aa067..a4191954a82928b7e6cd7ea79073cc2f0142f256 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -173,6 +173,14 @@ if (NOT FOUND_MATCH) target_link_libraries(test-elementwiseadd-op paddle-mobile) # gen test + ADD_EXECUTABLE(test-elementwisesub-op operators/test_elementwise_sub_op.cpp test_helper.h test_include.h) + target_link_libraries(test-elementwisesub-op paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-im2sequence-op operators/test_im2sequence_op.cpp test_helper.h test_include.h) + target_link_libraries(test-im2sequence-op paddle-mobile) + + # gen test ADD_EXECUTABLE(test-concat-op operators/test_concat_op.cpp test_helper.h test_include.h) target_link_libraries(test-concat-op paddle-mobile) diff --git a/test/operators/test_elementwise_sub_op.cpp b/test/operators/test_elementwise_sub_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cfac83eff7a012d52d47f96e088bd8519603cadc --- /dev/null +++ b/test/operators/test_elementwise_sub_op.cpp @@ -0,0 +1,159 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "../test_helper.h" +#include "../test_include.h" +#include "operators/elementwise_sub_op.h" + +namespace paddle_mobile { +namespace framework { + +template +class TestElementwiseSubOp { + public: + explicit TestElementwiseSubOp(const Program p) : program_(p) { + if (use_optimize_) { + to_predict_program_ = program_.optimizeProgram; + } else { + to_predict_program_ = program_.originProgram; + } + + const std::vector> blocks = + to_predict_program_->Blocks(); + // DLOG << " **block size " << blocks.size(); + for (int i = 0; i < blocks.size(); ++i) { + std::shared_ptr block_desc = blocks[i]; + std::vector> ops = block_desc->Ops(); + // DLOG << " ops " << ops.size(); + for (int j = 0; j < ops.size(); ++j) { + std::shared_ptr op = ops[j]; + if (op->Type() == "elementwise_sub" && + op->Input("X")[0] == "sigmoid_1.tmp_0") { + DLOG << " elementwise_sub attr size: " << op->GetAttrMap().size(); + DLOG << " inputs size: " << op->GetInputs().size(); + DLOG << " outputs size: " << op->GetOutputs().size(); + + std::shared_ptr> lrn = + std::make_shared>( + op->Type(), op->GetInputs(), op->GetOutputs(), + op->GetAttrMap(), program_.scope); + ops_of_block_[*block_desc.get()].push_back(lrn); + } + } + } + } + + std::shared_ptr predict_bn(const Tensor &t1, const Tensor &t2) { + // feed + auto scope = program_.scope; + Variable *x1_feed_value = scope->Var("tmp_0"); + auto tensor_x1 = x1_feed_value->GetMutable(); + tensor_x1->ShareDataWith(t1); + + Variable *x2_feed_value = scope->Var("sigmoid_1.tmp_0"); + auto tensor_x2 = x2_feed_value->GetMutable(); + tensor_x2->ShareDataWith(t2); + + Variable *output = scope->Var("tmp_1"); + auto *output_tensor = output->GetMutable(); + output_tensor->mutable_data({1, 1, 6, 6}); + // DLOG << typeid(output_tensor).name(); + // DLOG << "output_tensor dims: " << output_tensor->dims(); + + std::shared_ptr out_tensor = std::make_shared(); + out_tensor.reset(output_tensor); + + predict_bn(t1, t2, 0); + return out_tensor; + } + + private: + const framework::Program program_; + std::shared_ptr to_predict_program_; + std::map>>> + ops_of_block_; + bool use_optimize_ = false; + + void predict_bn(const Tensor &t1, const Tensor &t2, int block_id) { + std::shared_ptr to_predict_block = + to_predict_program_->Block(block_id); + for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) { + auto op = ops_of_block_[*to_predict_block.get()][j]; + DLOG << "op -> run()"; + op->Run(); + } + } +}; + +template class TestElementwiseSubOp; +} // namespace framework +} // namespace paddle_mobile + +int main() { + DLOG << "----------**********----------"; + DLOG << "begin to run ElementwiseSub Test"; + paddle_mobile::Loader loader; + auto program = loader.Load(std::string(g_ocr) + "/model", + std::string(g_ocr) + "/params"); + + /// input x1 (1,1,6,6) + paddle_mobile::framework::Tensor inputx1; + SetupTensor(&inputx1, {1, 1, 6, 6}, static_cast(0), + static_cast(1)); + auto *inputx1_ptr = inputx1.data(); + + /// input x2 (1,1,6,6) + paddle_mobile::framework::Tensor inputx2; + SetupTensor(&inputx2, {1, 1, 6, 6}, static_cast(0), + static_cast(1)); + auto *inputx2_ptr = inputx2.data(); + + paddle_mobile::framework::TestElementwiseSubOp + testElementwiseSubOp(program); + + auto output_op = testElementwiseSubOp.predict_bn(inputx1, inputx2); + auto *output_op_ptr = output_op->data(); + + auto inputx1_dim = inputx1.numel() / inputx1.dims()[0]; + DLOG << " input1 : "; + for (int i = 0; i < inputx1.dims()[0]; ++i) { + for (int j = 0; j < inputx1_dim; ++j) { + DLOGF("%f ", inputx1_ptr[i * inputx1_dim + j]); + } + DLOGF("\n"); + } + + auto inputx2_dim = inputx2.numel() / inputx2.dims()[0]; + DLOG << " input2 : "; + for (int i = 0; i < inputx2.dims()[0]; ++i) { + for (int j = 0; j < inputx2_dim; ++j) { + DLOGF("%f ", inputx2_ptr[i * inputx2_dim + j]); + } + DLOGF("\n"); + } + + auto output_dim = output_op->numel() / output_op->dims()[0]; + DLOG << " output : "; + for (int i = 0; i < output_op->dims()[0]; ++i) { + for (int j = 0; j < output_dim; ++j) { + DLOGF("%f ", output_op_ptr[i * output_dim + j]); + } + DLOGF("\n"); + } + + return 0; +} diff --git a/test/operators/test_im2sequence_op.cpp b/test/operators/test_im2sequence_op.cpp index a7512d3bf3cffcb100fe292e50fc7b7b23fa0aa0..b45e437e12f95cd9f7050247fc03a152246d8122 100644 --- a/test/operators/test_im2sequence_op.cpp +++ b/test/operators/test_im2sequence_op.cpp @@ -12,51 +12,129 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "../executor_for_test.h" +#pragma once + +#include "../test_helper.h" #include "../test_include.h" #include "operators/im2sequence_op.h" -int main() { - paddle_mobile::Loader loader; - auto program = loader.Load(g_ocr_recg); - PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr, - "program file read fail"); +namespace paddle_mobile { +namespace framework { - Executor4Test> - executor(program, "im2sequence"); +template +class TestIm2SequenceOp { + public: + explicit TestIm2SequenceOp(const Program p) : program_(p) { + if (use_optimize_) { + to_predict_program_ = program_.optimizeProgram; + } else { + to_predict_program_ = program_.originProgram; + } - // 1. input_tensors; - vector input_tensors; + const std::vector> blocks = + to_predict_program_->Blocks(); + // DLOG << " **block size " << blocks.size(); + for (int i = 0; i < blocks.size(); ++i) { + std::shared_ptr block_desc = blocks[i]; + std::vector> ops = block_desc->Ops(); + // DLOG << " ops " << ops.size(); + for (int j = 0; j < ops.size(); ++j) { + std::shared_ptr op = ops[j]; + if (op->Type() == "im2sequence" && + op->Input("X")[0] == "conv2d_19.tmp_1") { + DLOG << " im2squence attr size: " << op->GetAttrMap().size(); + DLOG << " inputs size: " << op->GetInputs().size(); + DLOG << " outputs size: " << op->GetOutputs().size(); - Tensor input1; - auto input1_data = CreateInput(&input1, {2, 2, 3, 3}, -1, 1); - input_tensors.push_back(input1); + std::shared_ptr> lrn = + std::make_shared>( + op->Type(), op->GetInputs(), op->GetOutputs(), + op->GetAttrMap(), program_.scope); + ops_of_block_[*block_desc.get()].push_back(lrn); + } + } + } + } - // 2. input_names - vector input_names({ - "conv2d_19.tmp_1", - }); + std::shared_ptr predict_bn(const Tensor &t1) { + // feed + auto scope = program_.scope; + Variable *x1_feed_value = scope->Var("conv2d_19.tmp_1"); + auto tensor_x1 = x1_feed_value->GetMutable(); + tensor_x1->ShareDataWith(t1); - // 3. output_names - vector output_names({"im2sequence_0.tmp_0"}); + Variable *output = scope->Var("im2sequence_0.tmp_0"); + auto *output_tensor = output->GetMutable(); + output_tensor->mutable_data({2, 12}); + // DLOG << typeid(output_tensor).name(); + // DLOG << "output_tensor dims: " << output_tensor->dims(); - // 4. out_dims; - vector out_ddims; - auto out_ddim = paddle_mobile::framework::make_ddim({8, 9}); - out_ddims.push_back(out_ddim); + std::shared_ptr out_tensor = std::make_shared(); + out_tensor.reset(output_tensor); - auto output = executor.Predict(input_tensors, input_names, - output_names, out_ddims); + predict_bn(t1, 0); + return out_tensor; + } - auto output0_data = output[0]->data(); + private: + const framework::Program program_; + std::shared_ptr to_predict_program_; + std::map>>> + ops_of_block_; + bool use_optimize_ = false; - for (int j = 0; j < input_tensors[0].numel(); ++j) { - DLOG << " value of input: " << input1_data[j]; + void predict_bn(const Tensor &t1, int block_id) { + std::shared_ptr to_predict_block = + to_predict_program_->Block(block_id); + for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) { + auto op = ops_of_block_[*to_predict_block.get()][j]; + DLOG << "op -> run()"; + op->Run(); + } } +}; + +template class TestIm2SequenceOp; +} // namespace framework +} // namespace paddle_mobile - for (int j = 0; j < output[0]->numel(); ++j) { - DLOG << " value of output: " << output0_data[j]; +int main() { + DLOG << "----------**********----------"; + DLOG << "begin to run Im2Sequence Test"; + paddle_mobile::Loader loader; + auto program = loader.Load(std::string(g_eng) + "/model", + std::string(g_eng) + "/params"); + + /// input x (4,10,2,2) + paddle_mobile::framework::Tensor inputx; + SetupTensor(&inputx, {1, 2, 6, 2}, static_cast(0), + static_cast(1)); + auto *inputx_ptr = inputx.data(); + + paddle_mobile::framework::TestIm2SequenceOp + testIm2SequenceOp(program); + + auto output_op = testIm2SequenceOp.predict_bn(inputx); + auto *output_op_ptr = output_op->data(); + + auto input_dim = inputx.numel() / inputx.dims()[0]; + DLOG << " input : "; + for (int i = 0; i < inputx.dims()[0]; ++i) { + for (int j = 0; j < input_dim; ++j) { + DLOGF("%f ", inputx_ptr[i * input_dim + j]); + } + DLOGF("\n"); } + + auto output_dim = output_op->numel() / output_op->dims()[0]; + DLOG << " output : "; + for (int i = 0; i < output_op->dims()[0]; ++i) { + for (int j = 0; j < output_dim; ++j) { + DLOGF("%f ", output_op_ptr[i * output_dim + j]); + } + DLOGF("\n"); + } + return 0; } diff --git a/test/operators/test_multiclass_nms_op.cpp b/test/operators/test_multiclass_nms_op.cpp index e6c41bd4b3bb241964a23accf4633e65818465be..d1b98d4965fd182ab1adc480279f38cea53974be 100644 --- a/test/operators/test_multiclass_nms_op.cpp +++ b/test/operators/test_multiclass_nms_op.cpp @@ -127,18 +127,25 @@ int main() { DLOG << "----------**********----------"; DLOG << "begin to run MulticlassNMS Test"; paddle_mobile::Loader loader; - auto program = loader.Load(std::string("../../test/models/mobilenet+ssd")); + auto program = loader.Load(std::string(g_mobilenet_ssd)); - /// input x (1,3,300,300) paddle_mobile::framework::Tensor inputx1; - SetupTensor(&inputx1, {10, 1917, 4}, static_cast(0), + SetupTensor(&inputx1, {1, 2, 4}, static_cast(0), static_cast(1)); auto *inputx1_ptr = inputx1.data(); + const float x1[] = {0, 0, 100, 100, 50, 50, 150, 150}; + for (int i = 0; i < 8; ++i) { + *(inputx1_ptr + i) = x1[i]; + } paddle_mobile::framework::Tensor inputx2; - SetupTensor(&inputx2, {10, 21, 1917}, static_cast(0), + SetupTensor(&inputx2, {1, 2, 2}, static_cast(0), static_cast(1)); auto *inputx2_ptr = inputx2.data(); + const float x2[] = {0.4, 0.3, 0.6, 0.7}; + for (int i = 0; i < 4; ++i) { + *(inputx2_ptr + i) = x2[i]; + } paddle_mobile::framework::TestMultiClassNMSOp testMultiClassNMSOp(program); @@ -146,8 +153,26 @@ int main() { auto output = testMultiClassNMSOp.predict(inputx1, inputx2); auto *output_ptr = output->data(); - for (int i = 0; i < output->numel(); i++) { + for (int i = 0; i < output->numel(); ++i) { DLOG << output_ptr[i]; } + + // test multi point + paddle_mobile::framework::Tensor inputx3; + SetupTensor(&inputx3, {1, 2, 8}, static_cast(0), + static_cast(1)); + auto *inputx3_ptr = inputx3.data(); + const float x3[] = {0, 0, 100, 0, 100, 100, 0, 100, + 50, 50, 150, 50, 150, 150, 50, 150}; + for (int i = 0; i < 16; ++i) { + *(inputx3_ptr + i) = x3[i]; + } + + auto output2 = testMultiClassNMSOp.predict(inputx3, inputx2); + auto *output_ptr2 = output2->data(); + + for (int i = 0; i < output2->numel(); ++i) { + DLOG << output_ptr2[i]; + } return 0; } diff --git a/tools/op.cmake b/tools/op.cmake index 4795568b8e64549d3d21fd5546ff2eec15a05012..3abe18bb7c74362bda4d564cea61ba31d61404bd 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -189,6 +189,8 @@ if(NOT FOUND_MATCH) set(CONV_OP ON) set(DEPTHWISECONV_OP ON) set(ELEMENTWISEADD_OP ON) + set(ELEMENTWISESUB_OP ON) + set(IM2SEQUENCE_OP ON) set(FUSION_CONVADD_OP ON) set(FUSION_CONVADDPRELU_OP ON) set(FUSION_CONVADDRELU_OP ON) @@ -264,6 +266,9 @@ endif() if (ELEMENTWISEADD_OP) add_definitions(-DELEMENTWISEADD_OP) endif() +if (ELEMENTWISESUB_OP) + add_definitions(-DELEMENTWISESUB_OP) +endif() if (FUSION_CONVADD_OP) add_definitions(-DFUSION_CONVADD_OP) endif()