提交 635b4958 编写于 作者: J juncaipeng 提交者: Xiaoyang LI

Add ops and fix bugs for Faster RCNN (#1942)

* add ops for faster rcnn

* disable test for generate_proposals and roi_align, test=develop

* remove .swp file

* remove log in tensor slice

* finish the unit test for roi_align, test=develop

* add box_clip op and fix tensor slice bug

* remove add four op twice

* rewrite the implement for box_coder and sequence_expand, add faster_rcnn_test, test=develop

* fix test bug of box_clip in x86 server, test=develop
上级 6d9b1558
......@@ -106,9 +106,10 @@ USE_LITE_KERNEL(generate_proposals, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(squeeze, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(squeeze2, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(expand, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(roi_align, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(box_clip, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(reduce_mean, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(stack, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(roi_align, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
......
......@@ -122,3 +122,4 @@ USE_LITE_OP(squeeze) // for x2paddle
USE_LITE_OP(squeeze2) // for x2paddle
USE_LITE_OP(expand) // for x2paddle
USE_LITE_OP(roi_align)
USE_LITE_OP(box_clip)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <fstream>
#include <vector>
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
#ifdef LITE_WITH_ARM
void TestModel(const std::vector<Place>& valid_places,
const Place& preferred_place) {
DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
auto* input_image = predictor.GetInput(0);
input_image->Resize({1, 3, 1333, 800});
auto* input_image_data = input_image->mutable_data<float>();
std::ifstream read_file("/data/local/tmp/pjc/faster_rcnn_img.txt");
for (int i = 0; i < input_image->numel(); i++) {
read_file >> input_image_data[i];
}
read_file.close();
LOG(INFO) << "image data:" << input_image_data[0] << " "
<< input_image_data[input_image->numel() - 1];
auto* im_info = predictor.GetInput(1);
im_info->Resize({1, 3});
auto* im_info_data = im_info->mutable_data<float>();
im_info_data[0] = 1333;
im_info_data[1] = 800;
im_info_data[2] = 1;
auto* im_shape = predictor.GetInput(2);
im_shape->Resize({1, 3});
auto* im_shape_data = im_shape->mutable_data<float>();
im_shape_data[0] = 1333;
im_shape_data[1] = 800;
im_shape_data[2] = 1;
for (int i = 0; i < FLAGS_warmup; ++i) {
predictor.Run();
}
auto start = GetCurrentUS();
for (int i = 0; i < FLAGS_repeats; ++i) {
predictor.Run();
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average.";
auto* out = predictor.GetOutput(0);
auto* out_data = out->data<float>();
LOG(INFO) << "==========output data===============";
for (int i = 0; i < out->numel(); i++) {
// LOG(INFO) << out_data[i];
}
/*
ASSERT_EQ(out->dims()[1], 6);
ASSERT_EQ(out->lod().size(), 1);
ASSERT_EQ(out->lod()[0].size(), 2);
ASSERT_EQ(out->lod()[0][0], 0);
ASSERT_EQ(out->lod()[0][1], 100);
*/
}
TEST(MobileNetV1_YoloV3, test_arm) {
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
});
TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
}
#endif // LITE_WITH_ARM
} // namespace lite
} // namespace paddle
......@@ -113,7 +113,8 @@ class TensorLite {
// For other devices, T and R may be the same type.
template <typename T, typename R = T>
const R *data() const {
return static_cast<const R *>(buffer_->data());
return reinterpret_cast<const R *>(static_cast<char *>(buffer_->data()) +
offset_);
}
void Resize(const DDimLite &ddim) { dims_ = ddim; }
......@@ -204,7 +205,7 @@ template <typename T, typename R>
R *TensorLite::mutable_data() {
memory_size_ = dims_.production() * sizeof(T);
buffer_->ResetLazy(target_, memory_size_);
return static_cast<R *>(buffer_->data());
return reinterpret_cast<R *>(static_cast<char *>(buffer_->data()) + offset_);
}
template <typename T, typename R>
......@@ -212,7 +213,7 @@ R *TensorLite::mutable_data(TargetType target) {
target_ = target;
memory_size_ = dims_.production() * sizeof(T);
buffer_->ResetLazy(target, memory_size());
return static_cast<R *>(buffer_->data());
return reinterpret_cast<R *>(static_cast<char *>(buffer_->data()) + offset_);
}
template <typename T>
......
......@@ -51,6 +51,7 @@ add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc D
add_kernel(anchor_generator_compute_arm ARM basic SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(generate_proposals_compute_arm ARM basic SRCS generate_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(roi_align_compute_arm ARM basic SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(box_clip_compute_arm ARM basic SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
# for OCR specific
add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/box_clip_compute.h"
#include <string>
#include <vector>
#include "lite/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <class T>
void ClipTiledBoxes(const Tensor& im_info,
const Tensor& input_boxes,
Tensor* out) {
T* out_data = out->mutable_data<T>();
const T* im_info_data = im_info.data<T>();
const T* input_boxes_data = input_boxes.data<T>();
T zero(0);
T im_w = round(im_info_data[1] / im_info_data[2]);
T im_h = round(im_info_data[0] / im_info_data[2]);
for (int64_t i = 0; i < input_boxes.numel(); ++i) {
if (i % 4 == 0) {
out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero);
} else if (i % 4 == 1) {
out_data[i] = std::max(std::min(input_boxes_data[i], im_h - 1), zero);
} else if (i % 4 == 2) {
out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero);
} else {
out_data[i] = std::max(std::min(input_boxes_data[i], im_h - 1), zero);
}
}
}
void BoxClipCompute::Run() {
auto& param = Param<operators::BoxClipParam>();
const auto* input = param.Input;
const auto* im_info = param.ImInfo;
auto* output = param.Output;
output->mutable_data<float>();
if (input->lod().size() > 1) {
LOG(FATAL) << "Only support 0 and 1 level of LoD.";
}
auto box_lod = input->lod().back();
int64_t n = static_cast<int64_t>(box_lod.size() - 1);
for (int i = 0; i < n; ++i) {
Tensor im_info_slice = im_info->Slice<float>(i, i + 1);
auto* im_info_slice_data = im_info_slice.data<float>();
Tensor box_slice = input->Slice<float>(box_lod[i], box_lod[i + 1]);
Tensor output_slice = output->Slice<float>(box_lod[i], box_lod[i + 1]);
ClipTiledBoxes<float>(im_info_slice, box_slice, &output_slice);
}
return;
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(box_clip,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::BoxClipCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("ImInfo", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/box_clip_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class BoxClipCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::BoxClipParam;
void Run() override;
virtual ~BoxClipCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -22,7 +22,143 @@ namespace lite {
namespace kernels {
namespace arm {
void EncodeCenterSize(const Tensor* target_box,
const Tensor* prior_box,
const Tensor* prior_box_var,
const bool normalized,
const std::vector<float> variance,
float* output) {
int64_t row = target_box->dims()[0];
int64_t col = prior_box->dims()[0];
int64_t len = prior_box->dims()[1];
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
auto* target_box_data = target_box->data<float>();
auto* prior_box_data = prior_box->data<float>();
int64_t offset = i * col * len + j * len;
float prior_box_width = prior_box_data[j * len + 2] -
prior_box_data[j * len] + (normalized == false);
float prior_box_height = prior_box_data[j * len + 3] -
prior_box_data[j * len + 1] +
(normalized == false);
float prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2;
float prior_box_center_y =
prior_box_data[j * len + 1] + prior_box_height / 2;
float target_box_center_x =
(target_box_data[i * len + 2] + target_box_data[i * len]) / 2;
float target_box_center_y =
(target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2;
float target_box_width = target_box_data[i * len + 2] -
target_box_data[i * len] + (normalized == false);
float target_box_height = target_box_data[i * len + 3] -
target_box_data[i * len + 1] +
(normalized == false);
output[offset] =
(target_box_center_x - prior_box_center_x) / prior_box_width;
output[offset + 1] =
(target_box_center_y - prior_box_center_y) / prior_box_height;
output[offset + 2] =
std::log(std::fabs(target_box_width / prior_box_width));
output[offset + 3] =
std::log(std::fabs(target_box_height / prior_box_height));
}
}
if (prior_box_var) {
const float* prior_box_var_data = prior_box_var->data<float>();
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
for (int k = 0; k < 4; ++k) {
int64_t offset = i * col * len + j * len;
int64_t prior_var_offset = j * len;
output[offset + k] /= prior_box_var_data[prior_var_offset + k];
}
}
}
} else if (!(variance.empty())) {
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
for (int k = 0; k < 4; ++k) {
int64_t offset = i * col * len + j * len;
output[offset + k] /= static_cast<float>(variance[k]);
}
}
}
}
}
template <int axis, int var_size>
void DecodeCenterSize(const Tensor* target_box,
const Tensor* prior_box,
const Tensor* prior_box_var,
const bool normalized,
std::vector<float> variance,
float* output) {
int64_t row = target_box->dims()[0];
int64_t col = target_box->dims()[1];
int64_t len = target_box->dims()[2];
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
auto* target_box_data = target_box->data<float>();
auto* prior_box_data = prior_box->data<float>();
float var_data[4] = {1., 1., 1., 1.};
float* var_ptr = var_data;
int64_t offset = i * col * len + j * len;
int64_t prior_box_offset = axis == 0 ? j * len : i * len;
float prior_box_width = prior_box_data[prior_box_offset + 2] -
prior_box_data[prior_box_offset] +
(normalized == false);
float prior_box_height = prior_box_data[prior_box_offset + 3] -
prior_box_data[prior_box_offset + 1] +
(normalized == false);
float prior_box_center_x =
prior_box_data[prior_box_offset] + prior_box_width / 2;
float prior_box_center_y =
prior_box_data[prior_box_offset + 1] + prior_box_height / 2;
float target_box_center_x = 0, target_box_center_y = 0;
float target_box_width = 0, target_box_height = 0;
int64_t prior_var_offset = axis == 0 ? j * len : i * len;
if (var_size == 2) {
std::memcpy(var_ptr,
prior_box_var->data<float>() + prior_var_offset,
4 * sizeof(float));
} else if (var_size == 1) {
var_ptr = reinterpret_cast<float*>(variance.data());
}
float box_var_x = *var_ptr;
float box_var_y = *(var_ptr + 1);
float box_var_w = *(var_ptr + 2);
float box_var_h = *(var_ptr + 3);
target_box_center_x =
box_var_x * target_box_data[offset] * prior_box_width +
prior_box_center_x;
target_box_center_y =
box_var_y * target_box_data[offset + 1] * prior_box_height +
prior_box_center_y;
target_box_width =
std::exp(box_var_w * target_box_data[offset + 2]) * prior_box_width;
target_box_height =
std::exp(box_var_h * target_box_data[offset + 3]) * prior_box_height;
output[offset] = target_box_center_x - target_box_width / 2;
output[offset + 1] = target_box_center_y - target_box_height / 2;
output[offset + 2] =
target_box_center_x + target_box_width / 2 - (normalized == false);
output[offset + 3] =
target_box_center_y + target_box_height / 2 - (normalized == false);
}
}
}
void BoxCoderCompute::Run() {
/*
auto& param = Param<operators::BoxCoderParam>();
int axis = param.axis;
bool box_normalized = param.box_normalized;
......@@ -35,6 +171,56 @@ void BoxCoderCompute::Run() {
code_type,
box_normalized,
axis);
*/
auto& param = Param<operators::BoxCoderParam>();
auto* prior_box = param.prior_box;
auto* prior_box_var = param.prior_box_var;
auto* target_box = param.target_box;
auto* output_box = param.proposals;
std::vector<float> variance = param.variance;
const int axis = param.axis;
std::string code_type = param.code_type;
bool normalized = param.box_normalized;
auto row = target_box->dims()[0];
auto col = prior_box->dims()[0];
if (code_type == "decode_center_size") {
col = target_box->dims()[1];
}
auto len = prior_box->dims()[1];
output_box->Resize({row, col, len});
auto* output = output_box->mutable_data<float>();
if (code_type == "encode_center_size") {
EncodeCenterSize(
target_box, prior_box, prior_box_var, normalized, variance, output);
} else if (code_type == "decode_center_size") {
if (prior_box_var) {
if (axis == 0) {
DecodeCenterSize<0, 2>(
target_box, prior_box, prior_box_var, normalized, variance, output);
} else {
DecodeCenterSize<1, 2>(
target_box, prior_box, prior_box_var, normalized, variance, output);
}
} else if (!(variance.empty())) {
if (axis == 0) {
DecodeCenterSize<0, 1>(
target_box, prior_box, prior_box_var, normalized, variance, output);
} else {
DecodeCenterSize<1, 1>(
target_box, prior_box, prior_box_var, normalized, variance, output);
}
} else {
if (axis == 0) {
DecodeCenterSize<0, 0>(
target_box, prior_box, prior_box_var, normalized, variance, output);
} else {
DecodeCenterSize<1, 0>(
target_box, prior_box, prior_box_var, normalized, variance, output);
}
}
}
}
} // namespace arm
......
......@@ -405,19 +405,6 @@ void GenerateProposalsCompute::Run() {
int64_t h_bbox = bbox_dim[2];
int64_t w_bbox = bbox_dim[3];
/*
LOG(INFO) << "scores dims:" << scores->dims() << " "
<< (scores->data<float>())[scores->numel() - 1];
LOG(INFO) << "bbox_deltas dims:" << bbox_deltas->dims() << " "
<< (bbox_deltas->data<float>())[bbox_deltas->numel() - 1];
LOG(INFO) << "im_info dims:" << im_info->dims() << " "
<< (im_info->data<float>())[im_info->numel() - 1];
LOG(INFO) << "anchors dims:" << anchors->dims() << " "
<< (anchors->data<float>())[anchors->numel() - 1];
LOG(INFO) << "variances dims:" << variances->dims() << " "
<< (variances->data<float>())[variances->numel() - 1];
*/
rpn_rois->Resize({scores->numel(), 4});
rpn_roi_probs->Resize(std::vector<int64_t>({scores->numel(), 1}));
......@@ -469,6 +456,21 @@ void GenerateProposalsCompute::Run() {
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
rpn_roi_probs->Resize({num_proposals, 1});
/*
auto* rpn_roi_probs_data = rpn_roi_probs->data<float>();
LOG(INFO) << "rpn_roi_probs:" << rpn_roi_probs->dims();
for (int i = 0; i < rpn_roi_probs->numel() - 4; i = i + 4) {
LOG(INFO) << rpn_roi_probs_data[i] << " " << rpn_roi_probs_data[i+1]
<< " " << rpn_roi_probs_data[i+2] << " " << rpn_roi_probs_data[i+3];
}
auto* rpn_roi_data = rpn_rois->data<float>();
LOG(INFO) << "rpn_roi:" << rpn_rois->dims();
for (int i = 0; i < rpn_rois->numel() - 4; i = i + 4) {
LOG(INFO) << rpn_roi_data[i] << " " << rpn_roi_data[i+1]
<< " " << rpn_roi_data[i+2] << " " << rpn_roi_data[i+3];
}
*/
}
} // namespace arm
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/arm/sequence_expand_compute.h"
#include <vector>
#include "lite/arm/math/funcs.h"
namespace paddle {
......@@ -20,9 +21,41 @@ namespace lite {
namespace kernels {
namespace arm {
void SequenceExpandFunc(const Tensor& x,
const std::vector<uint64_t>& x_lod,
const std::vector<uint64_t>& ref_lod,
Tensor* out) {
uint64_t out_offset = 0;
int64_t x_item_length = x.numel() / x.dims()[0];
auto out_data = out->mutable_data<float>();
auto x_data = x.data<float>();
for (size_t i = 1; i < ref_lod.size(); ++i) {
uint64_t repeat_num = ref_lod[i] - ref_lod[i - 1];
uint64_t x_start = x_lod[i - 1];
uint64_t x_end = x_lod[i];
uint64_t x_seq_len = x_end - x_start;
if (repeat_num > 0) {
uint64_t out_start = out_offset;
if (out->lod().size() == 1) {
out_start = out->lod()[0][out_offset];
}
for (uint64_t j = 0; j < repeat_num; j++) {
for (uint64_t k = 0; k < x_seq_len; k++) {
for (int l = 0; l < x_item_length; l++) {
out_data[(out_start + j * x_seq_len + k) * x_item_length + l] =
x_data[(x_start + k) * x_item_length + l];
}
}
}
}
out_offset += repeat_num;
}
}
void SequenceExpandCompute::PrepareForRun() {}
void SequenceExpandCompute::Run() {
/*
auto& param = Param<operators::SequenceExpandParam>();
const float* x_data = param.X->data<float>();
int width = param.X->numel() / param.X->dims()[0];
......@@ -35,6 +68,51 @@ void SequenceExpandCompute::Run() {
}
lite::arm::math::SequenceExpandImpl(
x_data, x_lod, width, y_lod[ref_level], output);
*/
auto& param = Param<operators::SequenceExpandParam>();
auto* x = param.X;
auto* y = param.Y;
auto* out = param.Out;
int ref_level = param.ref_level;
auto x_lod = x->lod();
auto y_lod = y->lod();
if (ref_level == -1) ref_level = y_lod.size() - 1;
out->mutable_data<float>();
if (y_lod[ref_level].size() <= 1) {
out->CopyDataFrom(*x);
return;
}
std::vector<uint64_t> out_lod;
if (x_lod.size() == 1) {
out_lod.push_back(0);
uint64_t out_offset = 0;
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
uint64_t repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
uint64_t x_start = x_lod[0][i - 1];
uint64_t x_end = x_lod[0][i];
uint64_t x_seq_len = x_end - x_start;
for (uint64_t j = 0; j < repeat_num; ++j) {
out_lod.push_back(out_lod.back() + x_seq_len);
out_offset++;
}
}
// write lod to out if x has lod
auto& ref_lod = *out->mutable_lod();
ref_lod[0] = out_lod;
}
std::vector<uint64_t> ref_x_lod;
if (x->lod().size() == 1) {
ref_x_lod = x->lod()[0];
} else {
ref_x_lod.resize(x->dims()[0] + 1);
std::iota(ref_x_lod.begin(), ref_x_lod.end(), 0);
}
SequenceExpandFunc(*x, ref_x_lod, y_lod[ref_level], out);
}
} // namespace arm
......
......@@ -63,6 +63,11 @@ add_operator(reduce_mean_op basic SRCS reduce_mean_op.cc DEPS ${op_DEPS})
add_operator(stack_op basic SRCS stack_op.cc DEPS ${op_DEPS})
add_operator(cast_op_lite basic SRCS cast_op.cc DEPS ${op_DEPS})
add_operator(assign_op basic SRCS assign_op.cc DEPS ${op_DEPS})
add_operator(affine_channel_op basic SRCS affine_channel_op.cc DEPS ${op_DEPS})
add_operator(anchor_generator_op basic SRCS anchor_generator_op.cc DEPS ${op_DEPS})
add_operator(generate_proposals_op basic SRCS generate_proposals_op.cc DEPS ${op_DEPS})
add_operator(roi_align_op basic SRCS roi_align_op.cc DEPS ${op_DEPS})
add_operator(box_clip_op basic SRCS box_clip_op.cc DEPS ${op_DEPS})
add_operator(flatten_op basic SRCS flatten_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_range_abs_max_op basic SRCS fake_quantize_range_abs_max.cc DEPS ${op_DEPS})
......@@ -91,10 +96,6 @@ add_operator(write_to_array_op extra SRCS write_to_array_op.cc DEPS ${op_DEPS})
add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS})
add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS})
add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
add_operator(affine_channel_op extra SRCS affine_channel_op.cc DEPS ${op_DEPS})
add_operator(anchor_generator_op extra SRCS anchor_generator_op.cc DEPS ${op_DEPS})
add_operator(generate_proposals_op extra SRCS generate_proposals_op.cc DEPS ${op_DEPS})
add_operator(roi_align_op extra SRCS roi_align_op.cc DEPS ${op_DEPS})
if (NOT LITE_WITH_X86)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/box_clip_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool BoxClipOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.Input);
CHECK_OR_FALSE(param_.ImInfo);
CHECK_OR_FALSE(param_.Output);
auto input_dims = param_.Input->dims();
auto im_info_dims = param_.ImInfo->dims();
auto input_box_size = input_dims.size();
CHECK_OR_FALSE(input_dims[input_box_size - 1] == 4);
CHECK_OR_FALSE(im_info_dims.size() == 2);
CHECK_OR_FALSE(im_info_dims[1] == 3);
return true;
}
bool BoxClipOpLite::InferShape() const {
auto* input = param_.Input;
auto* output = param_.Output;
output->Resize(input->dims());
output->set_lod(input->lod());
return true;
}
bool BoxClipOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
auto input = op_desc.Input("Input").front();
auto im_info = op_desc.Input("ImInfo").front();
auto output = op_desc.Output("Output").front();
param_.Input = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.ImInfo = scope->FindVar(im_info)->GetMutable<lite::Tensor>();
param_.Output = scope->FindVar(output)->GetMutable<lite::Tensor>();
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(box_clip, paddle::lite::operators::BoxClipOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class BoxClipOpLite : public OpLite {
public:
BoxClipOpLite() {}
explicit BoxClipOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "box clip"; }
private:
mutable BoxClipParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -21,34 +21,79 @@ namespace operators {
bool BoxCoderOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.prior_box);
CHECK_OR_FALSE(param_.prior_box_var);
CHECK_OR_FALSE(param_.target_box);
CHECK_OR_FALSE(param_.proposals);
auto prior_box_dims = param_.prior_box->dims();
CHECK_OR_FALSE(prior_box_dims.size() == 2);
CHECK_OR_FALSE(prior_box_dims[1] == 4);
if (param_.prior_box_var != nullptr) {
auto box_var_dim = param_.prior_box_var->dims();
CHECK_OR_FALSE(box_var_dim.size() == 2);
CHECK_OR_FALSE(box_var_dim == prior_box_dims);
}
return true;
}
bool BoxCoderOpLite::InferShape() const {
param_.proposals->Resize(param_.target_box->dims());
auto prior_box_dims = param_.prior_box->dims();
auto target_box_dims = param_.target_box->dims();
std::string code_type = param_.code_type;
int axis = param_.axis;
CHECK_OR_FALSE(code_type == "encode_center_size" ||
code_type == "decode_center_size");
if (code_type == "encode_center_size") {
CHECK_OR_FALSE(target_box_dims.size() == 2);
CHECK_OR_FALSE(target_box_dims[1] == 4);
param_.proposals->Resize({target_box_dims[0], prior_box_dims[0], 4});
} else if (code_type == "decode_center_size") {
CHECK_OR_FALSE(target_box_dims.size() == 3);
CHECK_OR_FALSE(axis == 0 || axis == 1);
if (axis == 0) {
CHECK_OR_FALSE(target_box_dims[1] == prior_box_dims[0]);
} else if (axis == 1) {
CHECK_OR_FALSE(target_box_dims[0] == prior_box_dims[0]);
}
CHECK_OR_FALSE(target_box_dims[2] == prior_box_dims[1]);
param_.proposals->Resize(target_box_dims);
}
if (code_type == "decode_center_size" && axis == 1) {
param_.proposals->set_lod(param_.prior_box->lod());
} else {
param_.proposals->set_lod(param_.target_box->lod());
}
return true;
}
bool BoxCoderOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
LOG(INFO) << "Attach Impl succeed!";
auto Prior_box_name = opdesc.Input("PriorBox").front();
auto Prior_box_var_name = opdesc.Input("PriorBoxVar").front();
auto Target_box_name = opdesc.Input("TargetBox").front();
auto Output_box_name = opdesc.Output("OutputBox").front();
param_.prior_box = GetVar<lite::Tensor>(scope, Prior_box_name);
param_.prior_box_var = GetVar<lite::Tensor>(scope, Prior_box_var_name);
param_.target_box = GetVar<lite::Tensor>(scope, Target_box_name);
param_.proposals = GetMutableVar<lite::Tensor>(scope, Output_box_name);
if (opdesc.HasAttr("axis")) {
param_.axis = opdesc.GetAttr<int>("axis");
// optional params
std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
if (std::find(input_arg_names.begin(),
input_arg_names.end(),
"PriorBoxVar") != input_arg_names.end()) {
auto box_var_arguments = opdesc.Input("PriorBoxVar");
if (box_var_arguments.size() > 0) {
auto* box_var = scope->FindVar(box_var_arguments.front());
if (box_var != nullptr) {
param_.prior_box_var = box_var->GetMutable<Tensor>();
}
}
}
param_.box_normalized = opdesc.GetAttr<bool>("box_normalized");
param_.code_type = opdesc.GetAttr<std::string>("code_type");
LOG(INFO) << "Attach Impl exit!";
param_.box_normalized = opdesc.GetAttr<bool>("box_normalized");
param_.axis = opdesc.GetAttr<int>("axis");
if (opdesc.HasAttr("variance")) {
param_.variance = opdesc.GetAttr<std::vector<float>>("variance");
}
return true;
}
......
......@@ -490,10 +490,11 @@ struct BoxCoderParam {
const lite::Tensor* prior_box_var{};
const lite::Tensor* target_box{};
lite::Tensor* proposals{};
int axis{0};
bool box_normalized{true};
// code_type: encode_center_size and decode_center_size
std::string code_type;
std::string code_type{"encode_center_size"};
bool box_normalized{true};
int axis{0};
std::vector<float> variance{};
};
/// ----------------------- multiclass_nms operators ----------------------
......@@ -782,6 +783,7 @@ struct AssignParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
};
struct RoiAlignParam {
lite::Tensor* X{};
lite::Tensor* ROIs{};
......@@ -792,6 +794,12 @@ struct RoiAlignParam {
int sampling_ratio{-1};
};
struct BoxClipParam {
const lite::Tensor* Input{};
const lite::Tensor* ImInfo{};
lite::Tensor* Output{};
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -50,23 +50,20 @@ bool SequenceExpandOp::InferShape() const {
const auto y_lod = param_.Y->lod()[ref_level];
auto out_dims = param_.X->dims();
int64_t out_first_dim = 0;
if (x_lod.size() > 0) {
if (y_lod.size() <= 1) {
out_first_dim = x_dims[0];
} else {
for (int i = 1; i < y_lod.size(); ++i) {
int64_t x_seq_len = 1;
if (x_lod.size() == 1) {
x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
}
out_first_dim += (y_lod[i] - y_lod[i - 1]) * x_seq_len;
if (y_lod.size() <= 1) {
out_first_dim = x_dims[0];
} else {
for (int i = 1; i < y_lod.size(); ++i) {
int64_t x_seq_len = 1;
if (x_lod.size() == 1) {
x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
}
out_dims[0] = out_first_dim;
out_first_dim += (y_lod[i] - y_lod[i - 1]) * x_seq_len;
}
} else {
out_dims[0] = -1;
out_dims[0] = out_first_dim;
}
param_.Out->Resize(out_dims);
param_.Out->set_lod(x_lod);
return true;
}
......@@ -79,10 +76,6 @@ bool SequenceExpandOp::AttachImpl(const cpp::OpDesc &opdesc,
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.ref_level = opdesc.GetAttr<int>("ref_level");
CHECK(param_.X);
CHECK(param_.Y);
CHECK(param_.Out);
return true;
}
......
......@@ -23,6 +23,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
#lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class BoxClipComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "input";
std::string im_info_ = "im_info";
std::string output_ = "output";
DDim input_dims_{};
LoD input_lod_{};
DDim im_info_dim_{};
public:
BoxClipComputeTester(const Place& place, const std::string& alias)
: TestCase(place, alias) {
input_dims_.ConstructFrom(std::vector<int64_t>({4, 3, 4}));
std::vector<uint64_t> lod0 = {0, 1, 4};
input_lod_.push_back(lod0);
im_info_dim_.ConstructFrom(std::vector<int64_t>({2, 3}));
}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(input_dims_);
auto* out_lod = out->mutable_lod();
*out_lod = input_lod_;
auto* out_data = out->mutable_data<float>();
auto* input = scope->FindTensor(input_);
const auto* input_data = input->data<float>();
for (int i = 0; i < 12; i++) {
out_data[i] = std::max(std::min(input_data[i], 9.f), 0.f);
}
for (int i = 12; i < 48; i++) {
out_data[i] = std::max(std::min(input_data[i], 14.f), 0.f);
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("box_clip");
op_desc->SetInput("Input", {input_});
op_desc->SetInput("ImInfo", {im_info_});
op_desc->SetOutput("Output", {output_});
}
void PrepareData() override {
std::vector<float> input_data(input_dims_.production());
for (int i = 0; i < input_dims_.production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
input_data[i] = sign * static_cast<float>((i * 7) % 20);
}
SetCommonTensor(input_, input_dims_, input_data.data());
auto input_tensor = baseline_scope()->FindMutableTensor(input_);
input_tensor->set_lod(input_lod_);
std::vector<float> im_info_data{10, 10, 1, 15, 15, 1};
SetCommonTensor(im_info_, im_info_dim_, im_info_data.data());
}
};
TEST(Boxclip, precision) {
LOG(INFO) << "test box_clip op";
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
std::unique_ptr<arena::TestCase> tester(
new BoxClipComputeTester(place, "def"));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册