Commit 53b05ce8 authored by juncaipeng, committed by Xiaoyang LI


add ops for faster rcnn, including affine_channel, anchor_generator, generate_proposals and roi_align (#1895)

* add ops for faster rcnn

* disable test for generate_proposals and roi_align, test=develop

* remove .swp file

* remove log in tensor slice

* finish the unit test for roi_align, test=develop
Parent 0c25428c
@@ -100,9 +100,13 @@ USE_LITE_KERNEL(shape, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(fill_constant, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(cast, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(slice, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(affine_channel, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(anchor_generator, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(generate_proposals, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(squeeze, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(squeeze2, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(expand, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(roi_align, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
@@ -113,6 +113,10 @@ USE_LITE_OP(is_empty)
USE_LITE_OP(shape)
USE_LITE_OP(slice)
USE_LITE_OP(cast)
USE_LITE_OP(affine_channel)
USE_LITE_OP(anchor_generator)
USE_LITE_OP(generate_proposals)
USE_LITE_OP(squeeze) // for x2paddle
USE_LITE_OP(squeeze2) // for x2paddle
USE_LITE_OP(expand) // for x2paddle
USE_LITE_OP(roi_align)
@@ -103,6 +103,8 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
sequence_pool.cc
sequence_expand.cc
slice.cc
affine_channel.cc
anchor_generator.cc
DEPS ${lite_kernel_deps})
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/arm/math/affine_channel.h"
#include <algorithm>
#include <limits>
#include <memory>
#include "lite/arm/math/axpy.h"
#include "lite/arm/math/funcs.h"
#include "lite/arm/math/saturate.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
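// Per-channel affine transform: out[n][c][h][w] = x[n][c][h][w] * scale[c] + bias[c],
// supporting both NCHW and NHWC layouts.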
void affine_channel_func(const float* x,
const float* scale,
const float* bias,
const std::string data_layout,
int num,
int channel,
int height,
int width,
float* out) {
if (data_layout == "NCHW") {
int hw_size = height * width;
for (int n = 0; n < num; n++) {
for (int c = 0; c < channel; c++) {
const float* x_ptr = x + n * channel * hw_size + c * hw_size;
const float* scale_ptr = scale + c;
const float* bias_ptr = bias + c;
float* out_ptr = out + n * channel * hw_size + c * hw_size;
for (int i = 0; i < hw_size; i++) {
*out_ptr = (*x_ptr) * (*scale_ptr) + (*bias_ptr);
x_ptr++;
out_ptr++;
}
}
}
} else if (data_layout == "NHWC") {
int nhw = num * height * width;
for (int i = 0; i < nhw; i++) {
const float* x_ptr = x + i * channel;
float* out_ptr = out + i * channel;
for (int c = 0; c < channel; c++) {
*out_ptr = (*x_ptr) * scale[c] + bias[c];
x_ptr++;
out_ptr++;
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "lite/operators/op_params.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void affine_channel_func(const float* x,
const float* scale,
const float* bias,
const std::string data_layout,
int num,
int channel,
int h,
int w,
float* dout);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/arm/math/anchor_generator.h"
#include <algorithm>
#include <limits>
#include <memory>
#include "lite/arm/math/funcs.h"
#include "lite/arm/math/saturate.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
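// For each (h, w) position of the feature map, generates
// num_anchors = aspect_ratios.size() * anchor_sizes.size() anchors as
// [xmin, ymin, xmax, ymax] boxes centered on that position (scaled by the
// stride and shifted by offset), and fills vars_ptr by repeating the four
// variance values for every anchor.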
void anchor_generator_func(int feature_height,
int feature_width,
std::vector<float> anchor_sizes,
std::vector<float> aspect_ratios,
std::vector<float> stride,
std::vector<float> variances,
float offset,
float* anchors_ptr,
float* vars_ptr) {
float stride_width = stride[0];
float stride_height = stride[1];
int num_anchors = aspect_ratios.size() * anchor_sizes.size();
for (int h_idx = 0; h_idx < feature_height; ++h_idx) {
float* anchors_ptr_h =
anchors_ptr + h_idx * feature_width * num_anchors * 4;
for (int w_idx = 0; w_idx < feature_width; ++w_idx) {
float* anchors_ptr_w = anchors_ptr_h + w_idx * num_anchors * 4;
float x_ctr = (w_idx * stride_width) + offset * (stride_width - 1);
float y_ctr = (h_idx * stride_height) + offset * (stride_height - 1);
float area, area_ratios;
float base_w, base_h;
float scale_w, scale_h;
float anchor_width, anchor_height;
int idx = 0;
for (size_t r = 0; r < aspect_ratios.size(); ++r) {
auto ar = aspect_ratios[r];
for (size_t s = 0; s < anchor_sizes.size(); ++s) {
auto anchor_size = anchor_sizes[s];
area = stride_width * stride_height;
area_ratios = area / ar;
base_w = round(sqrt(area_ratios));
base_h = round(base_w * ar);
scale_w = anchor_size / stride_width;
scale_h = anchor_size / stride_height;
anchor_width = scale_w * base_w;
anchor_height = scale_h * base_h;
anchors_ptr_w[idx++] = x_ctr - 0.5 * (anchor_width - 1);
anchors_ptr_w[idx++] = y_ctr - 0.5 * (anchor_height - 1);
anchors_ptr_w[idx++] = x_ctr + 0.5 * (anchor_width - 1);
anchors_ptr_w[idx++] = y_ctr + 0.5 * (anchor_height - 1);
}
}
}
}
int64_t hwn = feature_height * feature_width * num_anchors * 4;
for (int64_t i = 0; i < hwn; i++) {
*vars_ptr = variances[i % 4];
vars_ptr++;
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "lite/operators/op_params.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void anchor_generator_func(int feature_height,
int feature_width,
std::vector<float> anchor_sizes,
std::vector<float> aspect_ratios,
std::vector<float> stride,
std::vector<float> variances,
float offset,
float* anchors_data,
float* variances_data);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
@@ -19,6 +19,8 @@
#include <cmath>
#include "lite/arm/math/activation.h"
#include "lite/arm/math/affine_channel.h"
#include "lite/arm/math/anchor_generator.h"
#include "lite/arm/math/argmax.h"
#include "lite/arm/math/axpy.h"
#include "lite/arm/math/beam_search.h"
@@ -212,6 +212,7 @@ TensorLite TensorLite::Slice(int64_t begin, int64_t end) const {
dst_dims[0] = end - begin;
dst.Resize(dst_dims);
dst.offset_ = offset_ + static_cast<size_t>(begin * base) * sizeof(T);
return dst;
}
template <typename TensorT>
@@ -46,6 +46,10 @@ add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc
add_kernel(im2sequence_compute_arm ARM basic SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_pool_compute_arm ARM basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(assign_compute_arm ARM basic SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(anchor_generator_compute_arm ARM basic SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(generate_proposals_compute_arm ARM basic SRCS generate_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(roi_align_compute_arm ARM basic SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm)
# for OCR specific
add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/affine_channel_compute.h"
#include <string>
#include <vector>
#include "lite/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void AffineChannelCompute::Run() {
auto& param = Param<operators::AffineChannelParam>();
const lite::Tensor* x = param.X;
const lite::Tensor* scale = param.Scale;
const lite::Tensor* bias = param.Bias;
const std::string data_layout = param.data_layout;
lite::Tensor* out = param.Out;
auto x_dims = x->dims();
int num = x_dims[0];
int channel = 0;
int h = 0;
int w = 0;
if (data_layout == "NCHW") {
channel = x_dims[1];
h = x_dims[2];
w = x_dims[3];
} else if (data_layout == "NHWC") {
channel = x_dims[3];
h = x_dims[1];
w = x_dims[2];
}
lite::arm::math::affine_channel_func(x->data<float>(),
scale->data<float>(),
bias->data<float>(),
data_layout,
num,
channel,
h,
w,
out->mutable_data<float>());
return;
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(affine_channel,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::AffineChannelCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/affine_channel_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class AffineChannelCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::AffineChannelParam;
void Run() override;
virtual ~AffineChannelCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/anchor_generator_compute.h"
#include <string>
#include <vector>
#include "lite/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void AnchorGeneratorCompute::Run() {
auto& param = Param<operators::AnchorGeneratorParam>();
auto* anchors = param.Anchors;
auto* variances = param.Variances;
auto* input = param.Input;
float* anchors_data = anchors->mutable_data<float>();
float* variances_data = variances->mutable_data<float>();
auto input_dims = input->dims();
int feature_height = input_dims[2];
int feature_width = input_dims[3];
lite::arm::math::anchor_generator_func(feature_height,
feature_width,
param.anchor_sizes,
param.aspect_ratios,
param.stride,
param.variances,
param.offset,
anchors_data,
variances_data);
return;
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(anchor_generator,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::AnchorGeneratorCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Anchors", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Variances", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/anchor_generator_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class AnchorGeneratorCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::AnchorGeneratorParam;
void Run() override;
virtual ~AnchorGeneratorCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/generate_proposals_compute.h"
#include <string>
#include <utility>
#include <vector>
#include "lite/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
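// Permutes the axes of a 4-D tensor according to `orders` (e.g. {0, 2, 3, 1}
// reorders NCHW data into NHWC) by mapping every flattened output index back
// to its source index.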
static void permute(const Tensor &input,
Tensor *output,
const std::vector<int> &orders) {
auto in_dims = input.dims();
auto out_dims = output->dims();
int num_axes = in_dims.size();
int count = in_dims.production();
const float *din = input.data<float>();
float *dout = output->mutable_data<float>();
std::vector<int> old_steps(
{static_cast<int>(in_dims[1] * in_dims[2] * in_dims[3]),
static_cast<int>(in_dims[2] * in_dims[3]),
static_cast<int>(in_dims[3]),
1});
std::vector<int> new_steps(
{static_cast<int>(out_dims[1] * out_dims[2] * out_dims[3]),
static_cast<int>(out_dims[2] * out_dims[3]),
static_cast<int>(out_dims[3]),
1});
for (int i = 0; i < count; ++i) {
int old_idx = 0;
int idx = i;
for (int j = 0; j < num_axes; ++j) {
int order = orders[j];
old_idx += (idx / new_steps[j]) * old_steps[order];
idx %= new_steps[j];
}
dout[i] = din[old_idx];
}
}
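// Copies the rows of `src` selected by `index` into `output`; each row is a
// contiguous slice of size prod(src_dims[1:]).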
template <typename T, typename IndexT = int>
static void gather(const Tensor &src, const Tensor &index, Tensor *output) {
auto *p_src = src.data<T>();
auto *p_index = index.data<IndexT>();
auto *p_output = output->mutable_data<T>();
auto src_dims = src.dims();
int slice_size = 1;
for (int i = 1; i < src_dims.size(); i++) slice_size *= src_dims[i];
size_t slice_bytes = slice_size * sizeof(T);
int64_t index_size = index.numel();
for (int64_t i = 0; i < index_size; i++) {
IndexT index_ = p_index[i];
memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
}
}
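// Decodes bbox_deltas relative to all_anchors into proposal boxes
// [xmin, ymin, xmax, ymax]; width/height deltas pass through exp() clipped at
// kBBoxClipDefault and are optionally scaled by the per-anchor variances.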
template <class T>
static void BoxCoder(Tensor *all_anchors,
Tensor *bbox_deltas,
Tensor *variances,
Tensor *proposals) {
T *proposals_data = proposals->mutable_data<T>();
int64_t row = all_anchors->dims()[0];
int64_t len = all_anchors->dims()[1];
auto *bbox_deltas_data = bbox_deltas->data<T>();
auto *anchor_data = all_anchors->data<T>();
const T *variances_data = nullptr;
if (variances) {
variances_data = variances->data<T>();
}
for (int64_t i = 0; i < row; ++i) {
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0;
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1] + 1.0;
T anchor_center_x = anchor_data[i * len] + 0.5 * anchor_width;
T anchor_center_y = anchor_data[i * len + 1] + 0.5 * anchor_height;
T bbox_center_x = 0, bbox_center_y = 0;
T bbox_width = 0, bbox_height = 0;
if (variances) {
bbox_center_x =
variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width +
anchor_center_x;
bbox_center_y = variances_data[i * len + 1] *
bbox_deltas_data[i * len + 1] * anchor_height +
anchor_center_y;
bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2],
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3],
kBBoxClipDefault)) *
anchor_height;
} else {
bbox_center_x =
bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
kBBoxClipDefault)) *
anchor_height;
}
proposals_data[i * len] = bbox_center_x - bbox_width / 2;
proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1;
}
// return proposals;
}
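// Clips box coordinates to the image boundaries given by im_info
// ([height, width, scale]).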
template <class T>
static void ClipTiledBoxes(const Tensor &im_info, Tensor *boxes) {
T *boxes_data = boxes->mutable_data<T>();
const T *im_info_data = im_info.data<T>();
T zero(0);
for (int64_t i = 0; i < boxes->numel(); ++i) {
if (i % 4 == 0) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
} else if (i % 4 == 1) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
} else if (i % 4 == 2) {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
} else {
boxes_data[i] =
std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
}
}
}
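// Keeps the indices of boxes whose width/height, mapped back to the original
// image through im_info[2], are at least min_size and whose center lies inside
// the image; the kept indices are written into `keep`.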
template <class T>
static void FilterBoxes(Tensor *boxes,
float min_size,
const Tensor &im_info,
Tensor *keep) {
T *boxes_data = boxes->mutable_data<T>();
const T *im_info_data = im_info.data<T>();
T im_scale = im_info_data[2];
min_size = std::max(min_size, 1.0f);
keep->Resize(std::vector<int64_t>({boxes->dims()[0]}));
int *keep_data = keep->mutable_data<int>();
int keep_len = 0;
for (int i = 0; i < boxes->dims()[0]; ++i) {
T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
T ws_origin_scale =
(boxes_data[4 * i + 2] - boxes_data[4 * i]) / im_scale + 1;
T hs_origin_scale =
(boxes_data[4 * i + 3] - boxes_data[4 * i + 1]) / im_scale + 1;
T x_ctr = boxes_data[4 * i] + ws / 2;
T y_ctr = boxes_data[4 * i + 1] + hs / 2;
if (ws_origin_scale >= min_size && hs_origin_scale >= min_size &&
x_ctr <= im_info_data[1] && y_ctr <= im_info_data[0]) {
keep_data[keep_len++] = i;
}
}
keep->Resize(std::vector<int64_t>({keep_len}));
}
template <class T>
static std::vector<std::pair<T, int>> GetSortedScoreIndex(
const std::vector<T> &scores) {
std::vector<std::pair<T, int>> sorted_indices;
sorted_indices.reserve(scores.size());
for (size_t i = 0; i < scores.size(); ++i) {
sorted_indices.emplace_back(scores[i], i);
}
// Sort the score pairs by score in ascending order; NMS then pops from the
// back of the list, so the highest-scoring box is processed first
std::stable_sort(sorted_indices.begin(),
sorted_indices.end(),
[](const std::pair<T, int> &a, const std::pair<T, int> &b) {
return a.first < b.first;
});
return sorted_indices;
}
template <class T>
static T BBoxArea(const T *box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
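// Intersection-over-union of two boxes; when `normalized` is false, widths
// and heights include a +1 to count pixel coordinates inclusively.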
template <class T>
static T JaccardOverlap(const T *box1, const T *box2, bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = std::max(T(0), inter_xmax - inter_xmin + 1);
const T inter_h = std::max(T(0), inter_ymax - inter_ymin + 1);
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <class T>
static Tensor VectorToTensor(const std::vector<T> &selected_indices,
int selected_num) {
Tensor keep_nms;
keep_nms.Resize(std::vector<int64_t>({selected_num}));
auto *keep_data = keep_nms.mutable_data<T>();
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
}
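// Greedy non-maximum suppression: boxes are visited from highest to lowest
// score and a box is kept only if its overlap with every previously kept box
// does not exceed the threshold, which decays adaptively when eta < 1.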
template <class T>
static Tensor NMS(Tensor *bbox, Tensor *scores, T nms_threshold, float eta) {
int64_t num_boxes = bbox->dims()[0];
int64_t box_size = bbox->dims()[1]; // 4: [xmin ymin xmax ymax]
std::vector<T> scores_data(num_boxes);
std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices =
GetSortedScoreIndex<T>(scores_data);
std::vector<int> selected_indices;
int selected_num = 0;
T adaptive_threshold = nms_threshold;
const T *bbox_data = bbox->data<T>();
while (sorted_indices.size() != 0) {
int idx = sorted_indices.back().second;
bool flag = true;
for (int kept_idx : selected_indices) {
if (flag) {
T overlap = JaccardOverlap<T>(
bbox_data + idx * box_size, bbox_data + kept_idx * box_size, false);
flag = (overlap <= adaptive_threshold);
} else {
break;
}
}
if (flag) {
selected_indices.push_back(idx);
++selected_num;
}
sorted_indices.erase(sorted_indices.end() - 1);
if (flag && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
return VectorToTensor(selected_indices, selected_num);
}
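// Proposal pipeline for a single image: keep the pre_nms_top_n highest-scoring
// anchors, decode them with BoxCoder, clip to the image, drop boxes smaller
// than min_size, run NMS, and finally keep at most post_nms_top_n proposals.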
static std::pair<Tensor, Tensor> ProposalForOneImage(
const Tensor &im_info_slice,
const Tensor &anchors,
const Tensor &variances, // H * W * A * 4
const Tensor &bbox_deltas_slice, // [A, 4]
const Tensor &scores_slice, // [A, 1]
int pre_nms_top_n,
int post_nms_top_n,
float nms_thresh,
float min_size,
float eta) {
// sort scores_slice
Tensor index_t;
index_t.Resize(std::vector<int64_t>({scores_slice.numel()}));
auto *index = index_t.mutable_data<int>();
for (int i = 0; i < index_t.numel(); i++) {
index[i] = i;
}
auto *scores_data = scores_slice.data<float>();
auto compare_func = [scores_data](const int64_t &i, const int64_t &j) {
return scores_data[i] > scores_data[j];
};
if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) {
std::sort(index, index + scores_slice.numel(), compare_func);
} else {
std::nth_element(index,
index + pre_nms_top_n,
index + scores_slice.numel(),
compare_func);
index_t.Resize({pre_nms_top_n});
}
Tensor scores_sel, bbox_sel, anchor_sel, var_sel;
scores_sel.Resize(std::vector<int64_t>({index_t.numel(), 1}));
bbox_sel.Resize(std::vector<int64_t>({index_t.numel(), 4}));
anchor_sel.Resize(std::vector<int64_t>({index_t.numel(), 4}));
var_sel.Resize(std::vector<int64_t>({index_t.numel(), 4}));
gather<float>(scores_slice, index_t, &scores_sel);
gather<float>(bbox_deltas_slice, index_t, &bbox_sel);
gather<float>(anchors, index_t, &anchor_sel);
gather<float>(variances, index_t, &var_sel);
Tensor proposals;
proposals.Resize(std::vector<int64_t>({index_t.numel(), 4}));
BoxCoder<float>(&anchor_sel, &bbox_sel, &var_sel, &proposals);
ClipTiledBoxes<float>(im_info_slice, &proposals);
Tensor keep;
FilterBoxes<float>(&proposals, min_size, im_info_slice, &keep);
Tensor scores_filter;
scores_filter.Resize(std::vector<int64_t>({keep.numel(), 1}));
bbox_sel.Resize(std::vector<int64_t>({keep.numel(), 4}));
gather<float>(scores_sel, keep, &scores_filter);
gather<float>(proposals, keep, &bbox_sel);
if (nms_thresh <= 0) {
return std::make_pair(bbox_sel, scores_filter);
}
Tensor keep_nms = NMS<float>(&bbox_sel, &scores_filter, nms_thresh, eta);
if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
keep_nms.Resize(std::vector<int64_t>({post_nms_top_n}));
}
proposals.Resize(std::vector<int64_t>({keep_nms.numel(), 4}));
scores_sel.Resize(std::vector<int64_t>({keep_nms.numel(), 1}));
gather<float>(bbox_sel, keep_nms, &proposals);
gather<float>(scores_filter, keep_nms, &scores_sel);
return std::make_pair(proposals, scores_sel);
}
void AppendTensor(Tensor *dst, int64_t offset, const Tensor &src) {
auto *out_data = static_cast<void *>(dst->mutable_data<float>());
auto *to_add_data = static_cast<const void *>(src.data<float>());
size_t size_of_t = sizeof(float);
offset *= size_of_t;
std::memcpy(
reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(out_data) + offset),
to_add_data,
src.numel() * size_of_t);
}
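// Transposes scores/bbox_deltas from NCHW to NHWC, runs ProposalForOneImage
// for each batch item, and concatenates the results into RpnRois/RpnRoiProbs
// with a LoD recording the per-image offsets.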
void GenerateProposalsCompute::Run() {
auto &ctx = this->ctx_->template As<ARMContext>();
auto &param = Param<operators::GenerateProposalsParam>();
auto *scores = param.Scores; // N * A * H * W
auto *bbox_deltas = param.BboxDeltas; // N * 4A * H * W
auto *im_info = param.ImInfo; // N * 3
auto *anchors = param.Anchors; // H * W * A * 4
auto *variances = param.Variances; // H * W * A * 4
auto *rpn_rois = param.RpnRois; // A * 4
auto *rpn_roi_probs = param.RpnRoiProbs; // A * 1
int pre_nms_top_n = param.pre_nms_topN;
int post_nms_top_n = param.post_nms_topN;
float nms_thresh = param.nms_thresh;
float min_size = param.min_size;
float eta = param.eta;
auto &scores_dim = scores->dims();
int64_t num = scores_dim[0];
int64_t c_score = scores_dim[1];
int64_t h_score = scores_dim[2];
int64_t w_score = scores_dim[3];
auto &bbox_dim = bbox_deltas->dims();
int64_t c_bbox = bbox_dim[1];
int64_t h_bbox = bbox_dim[2];
int64_t w_bbox = bbox_dim[3];
/*
LOG(INFO) << "scores dims:" << scores->dims() << " "
<< (scores->data<float>())[scores->numel() - 1];
LOG(INFO) << "bbox_deltas dims:" << bbox_deltas->dims() << " "
<< (bbox_deltas->data<float>())[bbox_deltas->numel() - 1];
LOG(INFO) << "im_info dims:" << im_info->dims() << " "
<< (im_info->data<float>())[im_info->numel() - 1];
LOG(INFO) << "anchors dims:" << anchors->dims() << " "
<< (anchors->data<float>())[anchors->numel() - 1];
LOG(INFO) << "variances dims:" << variances->dims() << " "
<< (variances->data<float>())[variances->numel() - 1];
*/
rpn_rois->Resize({scores->numel(), 4});
rpn_roi_probs->Resize(std::vector<int64_t>({scores->numel(), 1}));
Tensor bbox_deltas_swap, scores_swap;
scores_swap.Resize(std::vector<int64_t>({num, h_score, w_score, c_score}));
bbox_deltas_swap.Resize(std::vector<int64_t>({num, h_bbox, w_bbox, c_bbox}));
std::vector<int> orders({0, 2, 3, 1});
permute(*scores, &scores_swap, orders);
permute(*bbox_deltas, &bbox_deltas_swap, orders);
LoD lod;
lod.resize(1);
auto &lod0 = lod[0];
lod0.push_back(0);
anchors->Resize(std::vector<int64_t>({anchors->numel() / 4, 4}));
variances->Resize(std::vector<int64_t>({variances->numel() / 4, 4}));
int64_t num_proposals = 0;
for (int64_t i = 0; i < num; ++i) {
Tensor im_info_slice = im_info->Slice<float>(i, i + 1);
Tensor bbox_deltas_slice = bbox_deltas_swap.Slice<float>(i, i + 1);
Tensor scores_slice = scores_swap.Slice<float>(i, i + 1);
bbox_deltas_slice.Resize(
std::vector<int64_t>({c_bbox * h_bbox * w_bbox / 4, 4}));
scores_slice.Resize(std::vector<int64_t>({c_score * h_score * w_score, 1}));
std::pair<Tensor, Tensor> tensor_pair =
ProposalForOneImage(im_info_slice,
*anchors,
*variances,
bbox_deltas_slice,
scores_slice,
pre_nms_top_n,
post_nms_top_n,
nms_thresh,
min_size,
eta);
Tensor &proposals = tensor_pair.first;
Tensor &scores = tensor_pair.second;
AppendTensor(rpn_rois, 4 * num_proposals, proposals);
AppendTensor(rpn_roi_probs, num_proposals, scores);
num_proposals += proposals.dims()[0];
lod0.push_back(num_proposals);
}
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
rpn_roi_probs->Resize({num_proposals, 1});
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(generate_proposals,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::GenerateProposalsCompute,
def)
.BindInput("Scores", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("BboxDeltas", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("ImInfo", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Anchors", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Variances", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("RpnRois", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("RpnRoiProbs", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/generate_proposals_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class GenerateProposalsCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::GenerateProposalsParam;
void Run() override;
virtual ~GenerateProposalsCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/roi_align_compute.h"
#include <string>
#include <vector>
#include "lite/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
static constexpr int kROISize = 4;
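// Precomputes, for every sampling point of every output bin, the flattened
// indices of its four bilinear neighbours and the corresponding interpolation
// weights; sampling points that fall outside the feature map get zero weights.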
template <class T>
void PreCalcForBilinearInterpolate(const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int iy_upper,
const int ix_upper,
T roi_ymin,
T roi_xmin,
T bin_size_h,
T bin_size_w,
int roi_bin_grid_h,
int roi_bin_grid_w,
Tensor* pre_pos,
Tensor* pre_w) {
int pre_calc_index = 0;
int* pre_pos_data = pre_pos->mutable_data<int>();
T* pre_w_data = pre_w->mutable_data<T>();
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
for (int iy = 0; iy < iy_upper; iy++) {
// calculate y of sample points
T y = roi_ymin + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
// calculate x of sample points
for (int ix = 0; ix < ix_upper; ix++) {
T x = roi_xmin + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
// deal with elements out of map
if (y < -1.0 || y > height || x < -1.0 || x > width) {
for (int i = 0; i < kROISize; ++i) {
pre_pos_data[i + pre_calc_index * kROISize] = 0;
pre_w_data[i + pre_calc_index * kROISize] = 0;
}
pre_calc_index += 1;
continue;
}
y = y <= 0 ? 0 : y;
x = x <= 0 ? 0 : x;
int y_low = static_cast<int>(y);
int x_low = static_cast<int>(x);
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = static_cast<T>(y_low);
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = static_cast<T>(x_low);
} else {
x_high = x_low + 1;
}
T ly = y - y_low, lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
pre_pos_data[pre_calc_index * kROISize] = y_low * width + x_low;
pre_pos_data[pre_calc_index * kROISize + 1] = y_low * width + x_high;
pre_pos_data[pre_calc_index * kROISize + 2] = y_high * width + x_low;
pre_pos_data[pre_calc_index * kROISize + 3] = y_high * width + x_high;
pre_w_data[pre_calc_index * kROISize] = hy * hx;
pre_w_data[pre_calc_index * kROISize + 1] = hy * lx;
pre_w_data[pre_calc_index * kROISize + 2] = ly * hx;
pre_w_data[pre_calc_index * kROISize + 3] = ly * lx;
pre_calc_index += 1;
}
}
}
}
}
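// For each ROI, splits it into pooled_height x pooled_width bins, samples
// roi_bin_grid_h x roi_bin_grid_w points per bin by bilinear interpolation
// (using the precomputed positions and weights), and averages them to produce
// each output element.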
void RoiAlignCompute::Run() {
auto& param = Param<operators::RoiAlignParam>();
auto* in = param.X;
auto* rois = param.ROIs;
auto* out = param.Out;
float spatial_scale = param.spatial_scale;
int pooled_height = param.pooled_height;
int pooled_width = param.pooled_width;
int sampling_ratio = param.sampling_ratio;
auto in_dims = in->dims();
int batch_size = in_dims[0];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
auto rois_dims = rois->dims();
int rois_num = rois_dims[0];
auto out_dims = out->dims();
if (rois_num == 0) {
return;
}
DDim in_stride({static_cast<int>(in_dims[1] * in_dims[2] * in_dims[3]),
static_cast<int>(in_dims[2] * in_dims[3]),
static_cast<int>(in_dims[3]),
1});
DDim roi_stride({static_cast<int>(rois_dims[1]), 1});
DDim out_stride({static_cast<int>(out_dims[1] * out_dims[2] * out_dims[3]),
static_cast<int>(out_dims[2] * out_dims[3]),
static_cast<int>(out_dims[3]),
1});
auto* input_data = in->data<float>();
Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>();
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
// CHECK_OR_FALSE(rois_batch_size == batch_size);
int rois_num_with_lod = rois_lod[rois_batch_size];
// CHECK_OR_FALSE(rois_num_with_lod == rois_num);
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
roi_batch_id_data[i] = n;
}
}
auto* output_data = out->mutable_data<float>();
auto* rois_data = rois->data<float>();
for (int n = 0; n < rois_num; ++n) {
int roi_batch_id = roi_batch_id_data[n];
float roi_xmin = rois_data[0] * spatial_scale;
float roi_ymin = rois_data[1] * spatial_scale;
float roi_xmax = rois_data[2] * spatial_scale;
float roi_ymax = rois_data[3] * spatial_scale;
float roi_width = std::max(roi_xmax - roi_xmin, 1.0f);
float roi_height = std::max(roi_ymax - roi_ymin, 1.0f);
float bin_size_h = roi_height / pooled_height;
float bin_size_w = roi_width / pooled_width;
const float* batch_data = input_data + roi_batch_id * in_stride[0];
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const float count = roi_bin_grid_h * roi_bin_grid_w;
Tensor pre_pos;
Tensor pre_w;
int pre_size = count * out_stride[1];
pre_pos.Resize({pre_size, kROISize});
pre_w.Resize({pre_size, kROISize});
PreCalcForBilinearInterpolate<float>(height,
width,
pooled_height,
pooled_width,
roi_bin_grid_h,
roi_bin_grid_w,
roi_ymin,
roi_xmin,
bin_size_h,
bin_size_w,
roi_bin_grid_h,
roi_bin_grid_w,
&pre_pos,
&pre_w);
const int* pre_pos_data = pre_pos.data<int>();
const float* pre_w_data = pre_w.data<float>();
for (int c = 0; c < channels; c++) {
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
const int pool_index = ph * pooled_width + pw;
float output_val = 0;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
for (int i = 0; i < kROISize; i++) {
int pos = pre_pos_data[pre_calc_index * kROISize + i];
float w = pre_w_data[pre_calc_index * kROISize + i];
output_val += w * batch_data[pos];
}
pre_calc_index += 1;
}
}
output_val /= count;
output_data[pool_index] = output_val;
}
}
batch_data += in_stride[1];
output_data += out_stride[1];
}
rois_data += roi_stride[0];
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(roi_align,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::RoiAlignCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("ROIs", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/roi_align_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class RoiAlignCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::RoiAlignParam;
void Run() override;
virtual ~RoiAlignCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -87,6 +87,10 @@ add_operator(write_to_array_op extra SRCS write_to_array_op.cc DEPS ${op_DEPS})
add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS})
add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS})
add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
add_operator(affine_channel_op extra SRCS affine_channel_op.cc DEPS ${op_DEPS})
add_operator(anchor_generator_op extra SRCS anchor_generator_op.cc DEPS ${op_DEPS})
add_operator(generate_proposals_op extra SRCS generate_proposals_op.cc DEPS ${op_DEPS})
add_operator(roi_align_op extra SRCS roi_align_op.cc DEPS ${op_DEPS})
if (NOT LITE_WITH_X86)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/affine_channel_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool AffineChannelOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Scale);
CHECK_OR_FALSE(param_.Bias);
CHECK_OR_FALSE(param_.Out);
const auto x_dims = param_.X->dims();
const auto scale_dims = param_.Scale->dims();
const auto bias_dims = param_.Bias->dims();
CHECK_OR_FALSE(x_dims.size() == 4);
CHECK_OR_FALSE(scale_dims.size() == 1);
CHECK_OR_FALSE(bias_dims.size() == 1);
CHECK_OR_FALSE(scale_dims == bias_dims);
const std::string data_layout = param_.data_layout;
if (data_layout == "NCHW") {
CHECK_OR_FALSE(scale_dims[0] == x_dims[1] && bias_dims[0] == x_dims[1]);
} else if (data_layout == "NHWC") {
CHECK_OR_FALSE(scale_dims[0] == x_dims[3] && bias_dims[0] == x_dims[3]);
}
return true;
}
bool AffineChannelOpLite::InferShape() const {
const auto x_dims = param_.X->dims();
param_.Out->Resize(x_dims);
return true;
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AffineChannelOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto x = op_desc.Input("X").front();
auto scale = op_desc.Input("Scale").front();
auto bias = op_desc.Input("Bias").front();
auto output = op_desc.Output("Out").front();
param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.Scale = scope->FindVar(scale)->GetMutable<lite::Tensor>();
param_.Bias = scope->FindVar(bias)->GetMutable<lite::Tensor>();
if (op_desc.HasAttr("data_layout")) {
param_.data_layout = op_desc.GetAttr<std::string>("data_layout");
}
param_.Out = scope->FindVar(output)->GetMutable<lite::Tensor>();
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(affine_channel, paddle::lite::operators::AffineChannelOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class AffineChannelOpLite : public OpLite {
public:
AffineChannelOpLite() {}
explicit AffineChannelOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "affine_channel"; }
private:
mutable AffineChannelParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/anchor_generator_op.h"
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool AnchorGeneratorOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.Input);
CHECK_OR_FALSE(param_.Anchors);
CHECK_OR_FALSE(param_.Variances);
auto input_dims = param_.Input->dims();
CHECK_OR_FALSE(input_dims.size() == 4);
return true;
}
bool AnchorGeneratorOpLite::InferShape() const {
auto input_dims = param_.Input->dims();
size_t num_anchors = param_.aspect_ratios.size() * param_.anchor_sizes.size();
std::vector<int64_t> output_shape(
{input_dims[2], input_dims[3], static_cast<int64_t>(num_anchors), 4});
param_.Anchors->Resize(output_shape);
param_.Variances->Resize(output_shape);
return true;
}
bool AnchorGeneratorOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto input_name = op_desc.Input("Input").front();
auto anchor_name = op_desc.Output("Anchors").front();
auto variances_name = op_desc.Output("Variances").front();
param_.Input = scope->FindVar(input_name)->GetMutable<lite::Tensor>();
param_.Anchors = scope->FindVar(anchor_name)->GetMutable<lite::Tensor>();
param_.Variances = scope->FindVar(variances_name)->GetMutable<lite::Tensor>();
param_.anchor_sizes = op_desc.GetAttr<std::vector<float>>("anchor_sizes");
param_.aspect_ratios = op_desc.GetAttr<std::vector<float>>("aspect_ratios");
param_.stride = op_desc.GetAttr<std::vector<float>>("stride");
if (op_desc.HasAttr("variances")) {
param_.variances = op_desc.GetAttr<std::vector<float>>("variances");
}
if (op_desc.HasAttr("offset")) {
param_.offset = op_desc.GetAttr<float>("offset");
}
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(anchor_generator,
paddle::lite::operators::AnchorGeneratorOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class AnchorGeneratorOpLite : public OpLite {
public:
AnchorGeneratorOpLite() {}
explicit AnchorGeneratorOpLite(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "anchor_generator"; }
private:
mutable AnchorGeneratorParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/generate_proposals_op.h"
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool GenerateProposalsOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.Scores);
CHECK_OR_FALSE(param_.BboxDeltas);
CHECK_OR_FALSE(param_.ImInfo);
CHECK_OR_FALSE(param_.Anchors);
CHECK_OR_FALSE(param_.Variances);
CHECK_OR_FALSE(param_.RpnRois);
CHECK_OR_FALSE(param_.RpnRoiProbs);
auto scores_dims = param_.Scores->dims();
auto bbox_dims = param_.BboxDeltas->dims();
auto im_info_dims = param_.ImInfo->dims();
auto anchors_dims = param_.Anchors->dims();
auto vars_dims = param_.Variances->dims();
CHECK_OR_FALSE(bbox_dims[1] == 4 * scores_dims[1]);
CHECK_OR_FALSE(scores_dims[1] == anchors_dims[2]);
CHECK_OR_FALSE(anchors_dims == vars_dims);
return true;
}
bool GenerateProposalsOpLite::InferShape() const {
param_.RpnRois->Resize(std::vector<int64_t>({-1, 4}));
param_.RpnRoiProbs->Resize(std::vector<int64_t>({-1, 1}));
return true;
}
bool GenerateProposalsOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
// inputs
param_.Scores = scope->FindVar(op_desc.Input("Scores").front())
->GetMutable<lite::Tensor>();
param_.BboxDeltas = scope->FindVar(op_desc.Input("BboxDeltas").front())
->GetMutable<lite::Tensor>();
param_.ImInfo = scope->FindVar(op_desc.Input("ImInfo").front())
->GetMutable<lite::Tensor>();
param_.Anchors = scope->FindVar(op_desc.Input("Anchors").front())
->GetMutable<lite::Tensor>();
param_.Variances = scope->FindVar(op_desc.Input("Variances").front())
->GetMutable<lite::Tensor>();
// attrs
param_.pre_nms_topN = op_desc.GetAttr<int>("pre_nms_topN");
param_.post_nms_topN = op_desc.GetAttr<int>("post_nms_topN");
param_.nms_thresh = op_desc.GetAttr<float>("nms_thresh");
param_.min_size = op_desc.GetAttr<float>("min_size");
param_.eta = op_desc.GetAttr<float>("eta");
// outs
param_.RpnRois = scope->FindVar(op_desc.Output("RpnRois").front())
->GetMutable<lite::Tensor>();
param_.RpnRoiProbs = scope->FindVar(op_desc.Output("RpnRoiProbs").front())
->GetMutable<lite::Tensor>();
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(generate_proposals,
paddle::lite::operators::GenerateProposalsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class GenerateProposalsOpLite : public OpLite {
public:
GenerateProposalsOpLite() {}
explicit GenerateProposalsOpLite(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "generate_proposals"; }
private:
mutable GenerateProposalsParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
@@ -695,6 +695,46 @@ struct SliceParam {
std::vector<int> decrease_axis{};
};
struct AffineChannelParam {
const lite::Tensor* X{}; // X is 4D tensor
const lite::Tensor* Scale{};
const lite::Tensor* Bias{};
std::string data_layout{"NCHW"}; // optional string from: NHWC, NCHW.
lite::Tensor* Out{};
};
struct AnchorGeneratorParam {
const lite::Tensor* Input{};
std::vector<float> anchor_sizes{};
std::vector<float> aspect_ratios{};
std::vector<float> stride{};
std::vector<float> variances{{0.1, 0.1, 0.2, 0.2}};
float offset{0.5};
lite::Tensor* Anchors{};
lite::Tensor* Variances{};
};
struct GenerateProposalsParam {
// inputs
const lite::Tensor* Scores{};
const lite::Tensor* BboxDeltas{};
const lite::Tensor* ImInfo{};
lite::Tensor* Anchors{};
lite::Tensor* Variances{};
// attrs
int pre_nms_topN{6000};
int post_nms_topN{1000};
float nms_thresh{0.5};
float min_size{0.1};
float eta{1.0};
// outputs
lite::Tensor* RpnRois{};
lite::Tensor* RpnRoiProbs{};
};
/// ----------------------- shape operators ----------------------
/// ----------------------- squeeze operators ----------------------
struct SqueezeParam {
const lite::Tensor* X{};
@@ -725,6 +765,16 @@ struct AssignParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
};
struct RoiAlignParam {
lite::Tensor* X{};
lite::Tensor* ROIs{};
lite::Tensor* Out{};
float spatial_scale{1.0};
int pooled_height{1};
int pooled_width{1};
int sampling_ratio{-1};
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/roi_align_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool RoiAlignOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.ROIs);
CHECK_OR_FALSE(param_.Out);
auto x_dims = param_.X->dims();
auto rois_dims = param_.ROIs->dims();
CHECK_OR_FALSE(x_dims.size() == 4);
CHECK_OR_FALSE(rois_dims.size() == 2);
CHECK_OR_FALSE(rois_dims[1] == 4);
CHECK_OR_FALSE(param_.pooled_height > 0);
CHECK_OR_FALSE(param_.pooled_width > 0);
CHECK_OR_FALSE(param_.spatial_scale > 0.0f);
return true;
}
bool RoiAlignOpLite::InferShape() const {
auto x_dims = param_.X->dims();
auto rois_dims = param_.ROIs->dims();
param_.Out->Resize(
{rois_dims[0], x_dims[1], param_.pooled_height, param_.pooled_width});
return true;
}
bool RoiAlignOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
param_.X =
scope->FindVar(op_desc.Input("X").front())->GetMutable<lite::Tensor>();
param_.ROIs =
scope->FindVar(op_desc.Input("ROIs").front())->GetMutable<lite::Tensor>();
param_.spatial_scale = op_desc.GetAttr<float>("spatial_scale");
param_.pooled_height = op_desc.GetAttr<int>("pooled_height");
param_.pooled_width = op_desc.GetAttr<int>("pooled_width");
param_.sampling_ratio = op_desc.GetAttr<int>("sampling_ratio");
param_.Out =
scope->FindVar(op_desc.Output("Out").front())->GetMutable<lite::Tensor>();
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(roi_align, paddle::lite::operators::RoiAlignOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class RoiAlignOpLite : public OpLite {
public:
RoiAlignOpLite() {}
explicit RoiAlignOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "roi_align"; }
private:
mutable RoiAlignParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -43,4 +43,8 @@ endif()
lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_roi_align_compute SRCS roi_align_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
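# The generate_proposals and roi_align tests above are kept disabled because
# they read reference data dumped from a full Paddle model run; see the notes
# in the corresponding *_compute_test.cc files.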
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class AffineChannelComputeTester : public arena::TestCase {
protected:
std::string input_ = "x";
std::string scale_ = "scale";
std::string bias_ = "bias";
std::string output_ = "out";
std::string data_layout_ = "";
DDim x_dims_{{2, 5, 20, 30}};
public:
AffineChannelComputeTester(const Place& place,
const std::string& alias,
int n,
int c,
int h,
int w,
std::string data_layout)
: TestCase(place, alias) {
data_layout_ = data_layout;
CHECK(data_layout_ == "NCHW" || data_layout_ == "NHWC");
if (data_layout_ == "NCHW") {
x_dims_ = DDim(std::vector<int64_t>({n, c, h, w}));
} else if (data_layout_ == "NHWC") {
x_dims_ = DDim(std::vector<int64_t>({n, h, w, c}));
}
}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(x_dims_);
auto* output_data = out->mutable_data<float>();
auto* x = scope->FindTensor(input_);
const auto* x_data = x->data<float>();
auto* scale = scope->FindTensor(scale_);
const auto* scale_data = scale->data<float>();
auto* bias = scope->FindTensor(bias_);
const auto* bias_data = bias->data<float>();
int num = x_dims_[0];
if (data_layout_ == "NCHW") {
int channel = x_dims_[1];
int height = x_dims_[2];
int width = x_dims_[3];
int size = x_dims_[2] * x_dims_[3];
int in_channel = channel * size;
for (int n = 0; n < num; n++) {
auto x_data_n = x_data + n * in_channel;
auto output_data_n = output_data + n * in_channel;
for (int c = 0; c < channel; c++) {
auto x_data_c = x_data_n + c * size;
auto output_data_c = output_data_n + c * size;
for (int k = 0; k < size; k++) {
output_data_c[k] = scale_data[c] * x_data_c[k] + bias_data[c];
}
}
}
} else if (data_layout_ == "NHWC") {
int channel = x_dims_[3];
int height = x_dims_[1];
int width = x_dims_[2];
int hwc = height * width * channel;
int wc = width * channel;
for (int n = 0; n < num; n++) {
for (int h = 0; h < height; h++) {
for (int w = 0; w < width; w++) {
auto x_ptr = x_data + n * hwc + h * wc + w * channel;
auto output_ptr = output_data + n * hwc + h * wc + w * channel;
for (int c = 0; c < channel; c++) {
output_ptr[c] = x_ptr[c] * scale_data[c] + bias_data[c];
}
}
}
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("affine_channel");
op_desc->SetInput("X", {input_});
op_desc->SetInput("Scale", {scale_});
op_desc->SetInput("Bias", {bias_});
op_desc->SetAttr("data_layout", data_layout_);
op_desc->SetOutput("Out", {output_});
}
void PrepareData() override {
std::vector<float> x_data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
x_data[i] = sign * static_cast<float>(i % 128) * 0.013f + 0.001;
}
SetCommonTensor(input_, x_dims_, x_data.data());
int c = data_layout_ == "NCHW" ? x_dims_[1] : x_dims_[3];
DDim scale_dims(std::vector<int64_t>({c}));
std::vector<float> scale_data(scale_dims.production());
for (int i = 0; i < scale_dims.production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
scale_data[i] = sign * static_cast<float>(i % 128) * 0.005f + 0.001;
}
SetCommonTensor(scale_, scale_dims, scale_data.data());
DDim bias_dims(std::vector<int64_t>({c}));
std::vector<float> bias_data(bias_dims.production());
for (int i = 0; i < bias_dims.production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
bias_data[i] = sign * static_cast<float>(i % 128) * 0.005f + 0.001;
}
SetCommonTensor(bias_, bias_dims, bias_data.data());
}
};
TEST(AffineChannel, precision) {
LOG(INFO) << "test affine_channel op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (int n : {1, 5}) {
for (int c : {2, 5}) {
for (int h : {3, 10}) {
for (int w : {3, 10}) {
for (std::string data_layout : {"NCHW", "NHWC"}) {
std::unique_ptr<arena::TestCase> tester(
new AffineChannelComputeTester(
place, "def", n, c, h, w, data_layout));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
}
#endif
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class AnchorGeneratorComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_str_ = "Input";
std::string anchors_str_ = "Anchors";
std::string variances_str_ = "Variances";
DDim input_dims_;
std::vector<float> anchor_sizes_;
std::vector<float> aspect_ratios_;
std::vector<float> stride_;
std::vector<float> variances_;
float offset_;
public:
AnchorGeneratorComputeTester(const Place& place,
const std::string& alias,
int n,
int c,
int h,
int w,
std::vector<float> anchor_sizes,
std::vector<float> aspect_ratios,
std::vector<float> stride,
std::vector<float> variances,
float offset)
: TestCase(place, alias) {
input_dims_ = DDim(std::vector<int64_t>({n, c, h, w}));
anchor_sizes_ = anchor_sizes;
aspect_ratios_ = aspect_ratios;
stride_ = stride;
variances_ = variances;
offset_ = offset;
}
void RunBaseline(Scope* scope) override {
auto* anchors = scope->NewTensor(anchors_str_);
auto* vars = scope->NewTensor(variances_str_);
CHECK(anchors);
CHECK(vars);
int num_anchors = anchor_sizes_.size() * aspect_ratios_.size();
std::vector<int64_t> output_shape(
{input_dims_[2], input_dims_[3], num_anchors, 4});
DDim output_dims(output_shape);
anchors->Resize(output_dims);
vars->Resize(output_dims);
auto* anchors_data = anchors->mutable_data<float>();
auto* vars_data = vars->mutable_data<float>();
int feature_height = input_dims_[2];
int feature_width = input_dims_[3];
float stride_width = stride_[0];
float stride_height = stride_[1];
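// Anchors are stored as {H, W, num_anchors, 4}; within each location the
// aspect ratio is the outer loop and the anchor size the inner one, so the
// box for (aspect_ratios_[r], anchor_sizes_[s]) sits at anchor index
// r * anchor_sizes_.size() + s.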
for (int h_idx = 0; h_idx < feature_height; ++h_idx) {
for (int w_idx = 0; w_idx < feature_width; ++w_idx) {
float x_ctr = (w_idx * stride_width) + offset_ * (stride_width - 1);
float y_ctr = (h_idx * stride_height) + offset_ * (stride_height - 1);
float area, area_ratios;
float base_w, base_h;
float scale_w, scale_h;
float anchor_width, anchor_height;
auto* anchors_data_hw = anchors_data +
h_idx * feature_width * num_anchors * 4 +
w_idx * num_anchors * 4;
for (size_t r = 0; r < aspect_ratios_.size(); ++r) {
auto ar = aspect_ratios_[r];
auto* anchors_data_r = anchors_data_hw + r * anchor_sizes_.size() * 4;
for (size_t s = 0; s < anchor_sizes_.size(); ++s) {
auto anchor_size = anchor_sizes_[s];
area = stride_width * stride_height;
area_ratios = area / ar;
base_w = round(sqrt(area_ratios));
base_h = round(base_w * ar);
scale_w = anchor_size / stride_width;
scale_h = anchor_size / stride_height;
anchor_width = scale_w * base_w;
anchor_height = scale_h * base_h;
anchors_data_r[s * 4 + 0] = (x_ctr - 0.5 * (anchor_width - 1));
anchors_data_r[s * 4 + 1] = (y_ctr - 0.5 * (anchor_height - 1));
anchors_data_r[s * 4 + 2] = (x_ctr + 0.5 * (anchor_width - 1));
anchors_data_r[s * 4 + 3] = (y_ctr + 0.5 * (anchor_height - 1));
}
}
}
}
for (int h = 0; h < feature_height; h++) {
for (int w = 0; w < feature_width; w++) {
for (int n = 0; n < num_anchors; n++) {
auto vars_data_i = vars_data + h * feature_width * num_anchors * 4 +
w * num_anchors * 4 + n * 4;
for (int i = 0; i < 4; i++) {
vars_data_i[i] = variances_[i];
}
}
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("anchor_generator");
op_desc->SetInput("Input", {input_str_});
op_desc->SetAttr("anchor_sizes", anchor_sizes_);
op_desc->SetAttr("aspect_ratios", aspect_ratios_);
op_desc->SetAttr("stride", stride_);
op_desc->SetAttr("variances", variances_);
op_desc->SetAttr("offset", offset_);
op_desc->SetOutput("Anchors", {anchors_str_});
op_desc->SetOutput("Variances", {variances_str_});
}
void PrepareData() override {
std::vector<float> input_data(input_dims_.production());
for (int i = 0; i < input_dims_.production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
input_data[i] = sign * static_cast<float>(i % 128) * 0.013f + 0.001;
}
SetCommonTensor(input_str_, input_dims_, input_data.data());
}
};
TEST(AnchorGenerator, precision) {
LOG(INFO) << "test anchor_generator op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (int n : {1, 3}) {
for (int c : {3, 6}) {
for (int h : {9, 18}) {
for (int w : {9, 18}) {
for (std::string str : {"NCHW", "NHWC"}) {
std::unique_ptr<arena::TestCase> tester(
new AnchorGeneratorComputeTester(place,
"def",
n,
c,
h,
w,
{64, 128, 256, 512},
{0.5, 1.0, 2.0},
{16.0, 16.0},
{0.1, 0.1, 0.2, 0.2},
0.5));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
}
#endif
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <fstream>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class GenerateProposalsComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string Scores_ = "Scores";
std::string BboxDeltas_ = "BboxDeltas";
std::string ImInfo_ = "ImInfo";
std::string Anchors_ = "Anchors";
std::string Variances_ = "Variances";
int pre_nms_topN_ = 6000;
int post_nms_topN_ = 1000;
float nms_thresh_ = 0.699999988079071;
float min_size_ = 0.0;
float eta_ = 1.0;
std::string RpnRois_ = "RpnRois";
std::string RpnRoiProbs_ = "RpnRoiProbs";
public:
GenerateProposalsComputeTester(const Place& place, const std::string& alias)
: TestCase(place, alias) {}
void RunBaseline(Scope* scope) override {
auto* rois = scope->NewTensor(RpnRois_);
auto* probs = scope->NewTensor(RpnRoiProbs_);
CHECK(rois);
CHECK(probs);
rois->Resize(std::vector<int64_t>({304, 4}));
probs->Resize(std::vector<int64_t>({304, 1}));
std::vector<uint64_t> lod0({0, 152, 304});
LoD lod;
lod.push_back(lod0);
rois->set_lod(lod);
probs->set_lod(lod);
auto* rois_data = rois->mutable_data<float>();
auto* probs_data = probs->mutable_data<float>();
std::string base_path = "/data/local/tmp/data_files/";
std::string filename;
std::ifstream reader;
// rois
filename = "result_generate_proposals_0.tmp_0.txt";
reader.open(base_path + filename);
for (int i = 0; i < rois->numel(); i++) {
reader >> rois_data[i];
}
LOG(INFO) << "Read Rois data." << rois_data[0];
reader.close();
// probs
filename = "result_generate_proposals_0.tmp_1.txt";
reader.open(base_path + filename);
for (int i = 0; i < probs->numel(); i++) {
reader >> probs_data[i];
}
LOG(INFO) << "Read Probs data." << probs_data[0];
reader.close();
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("generate_proposals");
op_desc->SetInput("Scores", {Scores_});
op_desc->SetInput("BboxDeltas", {BboxDeltas_});
op_desc->SetInput("ImInfo", {ImInfo_});
op_desc->SetInput("Anchors", {Anchors_});
op_desc->SetInput("Variances", {Variances_});
op_desc->SetAttr("pre_nms_topN", pre_nms_topN_);
op_desc->SetAttr("post_nms_topN", post_nms_topN_);
op_desc->SetAttr("nms_thresh", nms_thresh_);
op_desc->SetAttr("min_size", min_size_);
op_desc->SetAttr("eta", eta_);
op_desc->SetOutput("RpnRois", {RpnRois_});
op_desc->SetOutput("RpnRoiProbs", {RpnRoiProbs_});
}
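// The dumped shapes below describe one RPN head: Scores {2, 15, 84, 50} has 15
// anchors per location, BboxDeltas {2, 60, 84, 50} carries 4 deltas per anchor
// (60 = 4 * 15), and Anchors/Variances are {84, 50, 15, 4}.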
void PrepareData() override {
std::string base_path = "/data/local/tmp/data_files/";
std::string filename;
DDim dims;
std::vector<float> datas;
std::ifstream reader;
// Scores
filename = "result_rpn_cls_prob.tmp_0.txt";
dims = DDim(std::vector<int64_t>({2, 15, 84, 50}));
datas.resize(dims.production());
reader.open(base_path + filename);
for (int i = 0; i < dims.production(); i++) {
reader >> datas[i];
}
LOG(INFO) << "Read Scores data." << datas[0];
reader.close();
SetCommonTensor(Scores_, dims, datas.data());
// BboxDeltas
filename = "result_rpn_bbox_pred.tmp_1.txt";
dims = DDim(std::vector<int64_t>({2, 60, 84, 50}));
datas.resize(dims.production());
reader.open(base_path + filename);
for (int i = 0; i < dims.production(); i++) {
reader >> datas[i];
}
LOG(INFO) << "Read BboxDeltas data." << datas[0];
reader.close();
SetCommonTensor(BboxDeltas_, dims, datas.data());
// ImInfo
filename = "result_im_info.txt";
dims = DDim(std::vector<int64_t>({2, 3}));
datas.resize(dims.production());
reader.open(base_path + filename);
for (int i = 0; i < dims.production(); i++) {
reader >> datas[i];
}
LOG(INFO) << "Read ImInfo data." << datas[0];
reader.close();
SetCommonTensor(ImInfo_, dims, datas.data());
// Anchors
filename = "result_anchor_generator_0.tmp_0.txt";
dims = DDim(std::vector<int64_t>({84, 50, 15, 4}));
datas.resize(dims.production());
reader.open(base_path + filename);
for (int i = 0; i < dims.production(); i++) {
reader >> datas[i];
}
LOG(INFO) << "Read Anchors data." << datas[0];
reader.close();
SetCommonTensor(Anchors_, dims, datas.data());
// Variances
filename = "result_anchor_generator_0.tmp_1.txt";
dims = DDim(std::vector<int64_t>({84, 50, 15, 4}));
datas.resize(dims.production());
reader.open(base_path + filename);
for (int i = 0; i < dims.production(); i++) {
reader >> datas[i];
}
LOG(INFO) << "Read Variances data." << datas[0];
reader.close();
SetCommonTensor(Variances_, dims, datas.data());
}
};
TEST(GenerateProposals, precision) {
// The unit test for generate_proposals needs precomputed inputs and reference
// outputs, which are obtained by running the model with Paddle.
LOG(INFO) << "test generate proposals op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
std::unique_ptr<arena::TestCase> tester(
new GenerateProposalsComputeTester(place, "def"));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
#endif
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <fstream>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class RoiAlignComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string x_ = "X";
std::string rois_ = "ROIs";
std::string out_ = "Out";
float spatial_scale_ = 0.0625;
int pooled_height_ = 14;
int pooled_width_ = 14;
int sampling_ratio_ = 0;
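// spatial_scale_ = 0.0625 = 1/16 rescales RoI coordinates from image space to
// the feature map used as X (presumably a stride-16 res4 output), while
// sampling_ratio_ = 0 leaves the sample count to be derived from the RoI size.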
public:
RoiAlignComputeTester(const Place& place, const std::string& alias)
: TestCase(place, alias) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(out_);
CHECK(out);
out->Resize(std::vector<int64_t>({304, 1024, 14, 14}));
/*
std::vector<uint64_t> lod0({0, 152, 304});
LoD lod;
lod.push_back(lod0);
probs->set_lod(lod);
*/
auto* out_data = out->mutable_data<float>();
std::string base_path = "/data/local/tmp/roi_align_datas/";
std::string filename;
std::ifstream reader;
// out
filename = "result_roi_align_0.tmp_0.txt";
reader.open(base_path + filename);
LOG(INFO) << "Start read out data";
for (int i = 0; i < out->numel(); i++) {
reader >> out_data[i];
}
LOG(INFO) << "Read out data. " << out_data[0] << " "
<< out_data[out->numel() - 1];
reader.close();
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("roi_align");
op_desc->SetInput("X", {x_});
op_desc->SetInput("ROIs", {rois_});
op_desc->SetAttr("spatial_scale", spatial_scale_);
op_desc->SetAttr("pooled_height", pooled_height_);
op_desc->SetAttr("pooled_width", pooled_width_);
op_desc->SetAttr("sampling_ratio", sampling_ratio_);
op_desc->SetOutput("Out", {out_});
}
void PrepareData() override {
std::string base_path = "/data/local/tmp/roi_align_datas/";
std::string filename;
DDim dims;
std::vector<float> datas;
std::ifstream reader;
// x
filename = "result_res4f.add.output.5.tmp_0.txt";
dims = DDim(std::vector<int64_t>({2, 1024, 84, 50}));
datas.resize(dims.production());
reader.open(base_path + filename);
for (int i = 0; i < dims.production(); i++) {
reader >> datas[i];
}
LOG(INFO) << "Read x data. " << datas[0] << " " << datas.back();
reader.close();
SetCommonTensor(x_, dims, datas.data());
// rois
filename = "result_generate_proposals_0.tmp_0.txt";
dims = DDim(std::vector<int64_t>({304, 4}));
datas.resize(dims.production());
reader.open(base_path + filename);
for (int i = 0; i < dims.production(); i++) {
reader >> datas[i];
}
LOG(INFO) << "Read rois data. " << datas[0] << " " << datas.back();
reader.close();
SetCommonTensor(rois_, dims, datas.data());
auto rois_tensor = baseline_scope()->FindMutableTensor(rois_);
std::vector<uint64_t> lod0({0, 152, 304});
LoD lod;
lod.push_back(lod0);
rois_tensor->set_lod(lod);
}
};
TEST(RoiAlign, precision) {
// The unit test for roi_align needs precomputed inputs and a reference
// output, which are obtained by running the model with Paddle.
LOG(INFO) << "test roi align op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
std::unique_ptr<arena::TestCase> tester(
new RoiAlignComputeTester(place, "def"));
arena::Arena arena(std::move(tester), place, 2e-4);
arena.TestPrecision();
#endif
}
} // namespace lite
} // namespace paddle