Unverified commit 77b3062a, authored by HappyAngel, committed by GitHub

Merge pull request #127 from PaddlePaddle/develop

pull code
......@@ -10,3 +10,6 @@
[submodule "third-party/protobuf-host"]
path = third-party/protobuf-host
url = https://github.com/protocolbuffers/protobuf.git
[submodule "third-party/flatbuffers"]
path = third-party/flatbuffers
url = https://github.com/google/flatbuffers.git
......@@ -168,6 +168,7 @@ if(LITE_WITH_RKNPU)
include(device/rknpu)
endif()
include(external/flatbuffers)
# for mobile
if (WITH_LITE AND LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
# Introduce variables:
# * CMAKE_INSTALL_LIBDIR
INCLUDE(GNUInstallDirs)
SET(LIBDIR "lib")
if(CMAKE_INSTALL_LIBDIR MATCHES ".*lib64$")
SET(LIBDIR "lib64")
endif()
SET(FLATBUFFERS_SOURCES_DIR ${CMAKE_SOURCE_DIR}/third-party/flatbuffers)
SET(FLATBUFFERS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flatbuffers)
SET(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_INSTALL_DIR}/include" CACHE PATH "flatbuffers include directory." FORCE)
IF(WIN32)
set(FLATBUFFERS_LIBRARIES "${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.lib" CACHE FILEPATH "FLATBUFFERS_LIBRARIES" FORCE)
ELSE(WIN32)
set(FLATBUFFERS_LIBRARIES "${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.a" CACHE FILEPATH "FLATBUFFERS_LIBRARIES" FORCE)
ENDIF(WIN32)
INCLUDE_DIRECTORIES(${FLATBUFFERS_INCLUDE_DIR})
if(NOT HOST_CXX_COMPILER)
set(HOST_CXX_COMPILER ${CMAKE_CXX_COMPILER})
set(HOST_C_COMPILER ${CMAKE_C_COMPILER})
endif()
SET(OPTIONAL_ARGS "-DCMAKE_CXX_COMPILER=${HOST_CXX_COMPILER}"
"-DCMAKE_C_COMPILER=${HOST_C_COMPILER}")
ExternalProject_Add(
extern_flatbuffers
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/google/flatbuffers.git"
GIT_TAG "v1.12.0"
SOURCE_DIR ${FLATBUFFERS_SOURCES_DIR}
PREFIX ${FLATBUFFERS_INCLUDE_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DBUILD_STATIC_LIBS=ON
-DCMAKE_INSTALL_PREFIX=${FLATBUFFERS_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_TESTING=OFF
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${CROSS_COMPILE_CMAKE_ARGS}
${OPTIONAL_ARGS}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${FLATBUFFERS_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
)
IF(WIN32)
IF(NOT EXISTS "${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.lib")
add_custom_command(TARGET extern_flatbuffers POST_BUILD
COMMAND cmake -E copy ${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/flatbuffers_static.lib ${FLATBUFFERS_INSTALL_DIR}/${LIBDIR}/libflatbuffers.lib
)
ENDIF()
ENDIF(WIN32)
ADD_LIBRARY(flatbuffers STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET flatbuffers PROPERTY IMPORTED_LOCATION ${FLATBUFFERS_LIBRARIES})
ADD_DEPENDENCIES(flatbuffers extern_flatbuffers)
SET(FLATBUFFERS_FLATC_EXECUTABLE ${FLATBUFFERS_INSTALL_DIR}/bin/flatc)
function(register_generated_output file_name)
get_property(tmp GLOBAL PROPERTY FBS_GENERATED_OUTPUTS)
list(APPEND tmp ${file_name})
set_property(GLOBAL PROPERTY FBS_GENERATED_OUTPUTS ${tmp})
endfunction(register_generated_output)
function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT)
if(FLATBUFFERS_BUILD_LEGACY)
set(OPT ${OPT};--cpp-std c++0x)
else()
# --cpp-std is defined by flatc default settings.
endif()
message(STATUS "`${SRC_FBS}`: add generation of C++ code with '${OPT}'")
get_filename_component(SRC_FBS_DIR ${SRC_FBS} PATH)
message(STATUS "SRC_FBS_DIR: ${SRC_FBS_DIR}")
string(REGEX REPLACE "\\.fbs$" "_generated.h" GEN_HEADER ${SRC_FBS})
add_custom_command(
OUTPUT ${GEN_HEADER}
COMMAND "${FLATBUFFERS_FLATC_EXECUTABLE}"
--cpp --gen-mutable --gen-object-api --reflect-names
--cpp-ptr-type flatbuffers::unique_ptr # Used to test with C++98 STLs
${OPT}
-I "${CMAKE_CURRENT_SOURCE_DIR}/tests/include_test"
-o "${SRC_FBS_DIR}"
"${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS}"
DEPENDS flatbuffers
COMMENT "Run generation: '${GEN_HEADER}'")
include_directories(${FLATBUFFERS_INCLUDE_DIR})
register_generated_output(${GEN_HEADER})
add_custom_target(${TARGET} ALL DEPENDS ${GEN_HEADER})
endfunction()
set(FRAMEWORK_FBS_DIR "lite/model_parser/flatbuffers")
set(FRAMEWORK_SCHEMA_PATH "${FRAMEWORK_FBS_DIR}/framework.fbs")
compile_flatbuffers_schema_to_cpp_opt(framework_fbs_header ${FRAMEWORK_SCHEMA_PATH} "--no-includes;--gen-compare;--force-empty")
# C++ Train Demo
## Introduction
Paddle-Lite is well known for on-device inference, but it also supports model training on mobile devices. This document walks through a training example with Paddle-Lite; the task is Boston housing price prediction, also known as "fit-a-line".
You can learn more from the book repository:
[Documentation](https://paddlepaddle.org.cn/documentation/docs/zh/user_guides/simple_case/fit_a_line/README.cn.html)
[Source code](https://github.com/PaddlePaddle/book/tree/develop/01.fit_a_line)
......@@ -10,18 +12,16 @@
It models the problem with linear regression (Linear Regression). This document mainly describes how to migrate it to Paddle-Lite for training.
Note: this is a tutorial on model training with the C++ API; the other APIs do not support training yet.
## Requirements
- An Android phone, used to run the training program
- Python with Paddle (version >= 1.7.0) installed
## Quick start
### Step1 build paddle-lite
Please follow the [paddle-lite official documentation](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#paddlelite) to build the full_publish paddle-lite lib. Taking a Linux build as an example, the commands are:
```shell
## set up the environment
......@@ -51,7 +51,7 @@ cd Paddle-Lite
Paddle-Lite/build.lite.android.armv7.gcc/inference_lite_lib.android.armv7/cxx/lib/libpaddle_full_api_shared.so
```
### Step2 build lr_trainer
```shell
cd Paddle-Lite/lite/demo/cxx/train_demo/cplus_train/
......@@ -64,7 +64,7 @@ bin/
`-- demo_trainer
```
### Step3 download model and run it!
On your laptop, connect to the phone over USB, enable developer mode, and run the following in any directory:
......@@ -102,7 +102,7 @@ sample 8: Loss: 248.445
sample 9: Loss: 325.135
```
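The core of `demo_trainer` is an ordinary full-API loop: load the train program, feed a sample, run one step, and read back the loss. Below is a minimal sketch against the public full API in `paddle_api.h`; the model directory name, input indices, and tensor shapes are illustrative assumptions, not the demo's exact source.
```cpp
#include <cstdio>
#include <vector>
#include "paddle_api.h"  // Paddle-Lite full API

int main() {
  namespace api = paddle::lite_api;
  api::CxxConfig config;
  config.set_model_dir("lr_model");  // assumed dir holding the train program
  config.set_valid_places({api::Place{TARGET(kARM), PRECISION(kFloat)}});
  auto predictor = api::CreatePaddlePredictor<api::CxxConfig>(config);

  auto x = predictor->GetInput(0);
  x->Resize({1, 13});  // one Boston-housing sample: 13 features
  auto y = predictor->GetInput(1);
  y->Resize({1, 1});   // ground-truth price
  // ... fill x->mutable_data<float>() / y->mutable_data<float>() ...
  predictor->Run();    // one step of the train program

  auto loss = predictor->GetOutput(0);
  std::printf("loss: %f\n", loss->data<float>()[0]);
  return 0;
}
```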
## More details
The model used above was downloaded directly. If you want to generate it yourself, run the following commands:
```shell
......@@ -125,9 +125,9 @@ md5sum fc_0.w_0: 2c7b3649b2a9cf7bcd19f8b256ce795d
If you want to generate your own model for training, refer to how the model is saved in `train.py`.
## Verifying against Paddle training results
### The first 10 loss values
To verify the consistency between paddle and lite, we trained 10 batches with identical model parameters, identical data, and batch size = 1, and recorded the loss values of both.
......@@ -171,11 +171,11 @@ sample 8: Loss: 248.445
sample 9: Loss: 325.135
```
### Loss curve
With the training batch size fixed at 20 and the training data globally shuffled each epoch, the loss curves of paddle and lite after 100 epochs compare as follows.
![lr_loss](../images/lr_loss.png)
To reproduce the result above, the paddle+python command is:
......
......@@ -2,7 +2,7 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK OR (NOT LITE_WITH_LOG))
lite_cc_library(place SRCS paddle_place.cc DEPS logging)
else()
lite_cc_library(place SRCS paddle_place.cc DEPS glog)
endif()
if (LITE_ON_TINY_PUBLISH)
set(CMAKE_CXX_FLAGS_RELEASE "-Os -DNDEBUG")
......@@ -70,7 +70,7 @@ else()
set(TARGET_COMIPILE_FLAGS "${TARGET_COMIPILE_FLAGS} -flto")
endif()
set_target_properties(paddle_light_api_shared PROPERTIES COMPILE_FLAGS "${TARGET_COMIPILE_FLAGS}")
add_dependencies(paddle_light_api_shared op_list_h kernel_list_h framework_fbs_header)
if (LITE_WITH_NPU)
# Need to add HIAI runtime libs (libhiai.so) dependency
target_link_libraries(paddle_light_api_shared ${npu_builder_libs} ${npu_runtime_libs})
......
......@@ -326,10 +326,8 @@ void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &desc,
}
}
if (is_quantized_model) {
inner_places.insert(inner_places.begin(),
Place{TARGET(kARM), PRECISION(kInt8)});
}
Program program(*desc.get(), scope_, inner_places);
......
......@@ -66,7 +66,7 @@ class KernelBase {
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = std::string("NotImpl");
#ifdef LITE_WITH_OPENCL
ch->cl_event = event_;
#endif
}
......
......@@ -42,6 +42,7 @@ add_kernel(cast_compute_arm ARM basic SRCS cast_compute.cc DEPS ${lite_kernel_de
add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(stack_compute_arm ARM basic SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(affine_channel_compute_arm ARM basic SRCS affine_channel_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(affine_grid_compute_arm ARM basic SRCS affine_grid_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(range_compute_arm ARM basic SRCS range_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(layout_compute_arm ARM basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/affine_grid_compute.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/backends/arm/math/sgemm.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void AffineGridCompute::PrepareForRun() {
auto& param = Param<operators::AffineGridParam>();
auto& ctx = this->ctx_->template As<ARMContext>();
const lite::Tensor* x = param.X;
const float* din = x->data<float>();
lite::Tensor* out = param.Out;
float* dout = param.Out->mutable_data<float>();
int N = x->dims()[0];
int H = param.output_shape[2];
int W = param.output_shape[3];
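// Precompute normalized sampling coordinates: vh spans [-1, 1] over the H
// rows and vw spans [-1, 1] over the W columns (the affine_grid convention).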
vh = reinterpret_cast<float*>(malloc(sizeof(float) * H));
vw = reinterpret_cast<float*>(malloc(sizeof(float) * W));
int out_size = H * W * 3;
float scale = 2 / (static_cast<float>(H) - 1);
for (int i = 0; i < H; i++) {
vh[i] = -1 + scale * i;
}
scale = 2 / (static_cast<float>(W) - 1);
for (int i = 0; i < W; i++) {
vw[i] = -1 + i * scale;
}
return;
}
void AffineGridCompute::Run() {
auto& param = Param<operators::AffineGridParam>();
auto& ctx = this->ctx_->template As<ARMContext>();
const lite::Tensor* x = param.X;
int N = x->dims()[0];
int H = param.output_shape[2];
int W = param.output_shape[3];
int out_size = H * W * 3;
float* hw3 = ctx.workspace_data<float>() + ctx.llc_size() / sizeof(float);
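// Build an (H*W) x 3 matrix of homogeneous coordinates [x, y, 1] in the
// workspace; the constant 1 lets the affine translation enter the GEMM.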
for (int i = 0; i < out_size; i += 3) {
hw3[i] = 1;
hw3[i + 1] = 1;
hw3[i + 2] = 1;
}
for (int i = 0; i < H * W; i++) {
hw3[i * 3 + 1] = vh[i / W];
}
for (int i = 0; i < H * W; i++) {
hw3[i * 3] = vw[i % W];
}
const float* din = x->data<float>();
float* dout = param.Out->mutable_data<float>();
float* tmp = dout;
operators::ActivationParam act_param;
act_param.has_active = false;
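// Per batch item: grid(H*W, 2) = hw3(H*W, 3) * theta(2, 3)^T, issued as an
// sgemm with the second operand transposed; each theta holds 6 floats.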
for (int i = 0; i < N; i++) {
lite::arm::math::sgemm(false,
true,
H * W,
2,
3,
1.f,
hw3,
3,
din,
3,
0.f,
dout,
2,
nullptr,
false,
act_param,
&ctx);
din += 6;
dout += H * W * 2;
}
return;
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(affine_grid,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::AffineGridCompute,
def)
.BindInput("Theta", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
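For reference, the math realized above through a batched sgemm can be written as a direct loop. This is a minimal sketch of the affine_grid semantics (output laid out N x H x W x 2), assuming H, W > 1; it is illustrative, not the optimized ARM path:
```cpp
// Naive affine_grid: theta is N x 2 x 3, output grid is N x H x W x 2.
// Every output location maps its normalized coordinate [x, y, 1] through
// the per-sample affine matrix theta[n].
void AffineGridReference(const float* theta, int N, int H, int W, float* out) {
  for (int n = 0; n < N; ++n) {
    const float* t = theta + n * 6;
    for (int h = 0; h < H; ++h) {
      float y = -1.f + 2.f * h / (H - 1);  // normalized row coordinate
      for (int w = 0; w < W; ++w) {
        float x = -1.f + 2.f * w / (W - 1);  // normalized column coordinate
        float* o = out + ((n * H + h) * W + w) * 2;
        o[0] = t[0] * x + t[1] * y + t[2];
        o[1] = t[3] * x + t[4] * y + t[5];
      }
    }
  }
}
```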
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class AffineGridCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::AffineGridParam;
void PrepareForRun() override;
void Run() override;
virtual ~AffineGridCompute() = default;
float* vh;
float* vw;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -88,7 +88,7 @@ void SequenceConvCompute::Run() {
paddle::lite::arm::math::im2col(
sub_in_data,
1,
input_row_end - input_row_begin,
hidden_dim, // C H W -> 1, seq_len, hidden_dim
kernel_size,
hidden_dim, // kernel_h, kernel_w
......
......@@ -16,6 +16,7 @@ add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${li
add_kernel(write_to_array_compute_host Host extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps})
add_kernel(read_from_array_compute_host Host extra SRCS read_from_array_compute.cc DEPS ${lite_kernel_deps})
add_kernel(assign_compute_host Host extra SRCS assign_compute.cc DEPS ${lite_kernel_deps})
add_kernel(retinanet_detection_output_compute_host Host extra SRCS retinanet_detection_output_compute.cc DEPS ${lite_kernel_deps})
add_kernel(where_index_compute_host Host extra SRCS where_index_compute.cc DEPS ${lite_kernel_deps})
if(LITE_BUILD_EXTRA)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/retinanet_detection_output_compute.h"
#include <cmath>
#include <map>
#include <utility>
#include <vector>
#include "lite/operators/retinanet_detection_output_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
template <class T>
bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
}
template <class T>
bool SortScoreTwoPairDescend(const std::pair<float, std::pair<T, T>>& pair1,
const std::pair<float, std::pair<T, T>>& pair2) {
return pair1.first > pair2.first;
}
template <class T>
static inline void GetMaxScoreIndex(
const std::vector<T>& scores,
const T threshold,
int top_k,
std::vector<std::pair<T, int>>* sorted_indices) {
for (size_t i = 0; i < scores.size(); ++i) {
if (scores[i] > threshold) {
sorted_indices->push_back(std::make_pair(scores[i], i));
}
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(),
sorted_indices->end(),
SortScorePairDescend<int>);
// Keep top_k scores if needed.
if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
sorted_indices->resize(top_k);
}
}
template <class T>
static inline T BBoxArea(const std::vector<T>& box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static inline T JaccardOverlap(const std::vector<T>& box1,
const std::vector<T>& box2,
const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
T norm = normalized ? static_cast<T>(0.) : static_cast<T>(1.);
T inter_w = inter_xmax - inter_xmin + norm;
T inter_h = inter_ymax - inter_ymin + norm;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <class T>
void NMSFast(const std::vector<std::vector<T>>& cls_dets,
const T nms_threshold,
const T eta,
std::vector<int>* selected_indices) {
int64_t num_boxes = cls_dets.size();
std::vector<std::pair<T, int>> sorted_indices;
for (int64_t i = 0; i < num_boxes; ++i) {
sorted_indices.push_back(std::make_pair(cls_dets[i][4], i));
}
// Sort the score pair according to the scores in descending order
std::stable_sort(
sorted_indices.begin(), sorted_indices.end(), SortScorePairDescend<int>);
selected_indices->clear();
T adaptive_threshold = nms_threshold;
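// With eta < 1, the IoU threshold decays after each kept box (see the loop
// tail), making the suppression progressively stricter.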
while (sorted_indices.size() != 0) {
const int idx = sorted_indices.front().second;
bool keep = true;
for (size_t k = 0; k < selected_indices->size(); ++k) {
if (keep) {
const int kept_idx = (*selected_indices)[k];
T overlap = T(0.);
overlap = JaccardOverlap<T>(cls_dets[idx], cls_dets[kept_idx], false);
keep = overlap <= adaptive_threshold;
} else {
break;
}
}
if (keep) {
selected_indices->push_back(idx);
}
sorted_indices.erase(sorted_indices.begin());
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
}
template <class T>
void DeltaScoreToPrediction(
const std::vector<T>& bboxes_data,
const std::vector<T>& anchors_data,
T im_height,
T im_width,
T im_scale,
int class_num,
const std::vector<std::pair<T, int>>& sorted_indices,
std::map<int, std::vector<std::vector<T>>>* preds) {
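// Decode box regression deltas against their anchors (center/size encoding),
// rescale from the resized image back to the original, and clip to bounds.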
im_height = static_cast<T>(std::round(im_height / im_scale));
im_width = static_cast<T>(std::round(im_width / im_scale));
T zero(0);
int i = 0;
for (const auto& it : sorted_indices) {
T score = it.first;
int idx = it.second;
int a = idx / class_num;
int c = idx % class_num;
int box_offset = a * 4;
T anchor_box_width =
anchors_data[box_offset + 2] - anchors_data[box_offset] + 1;
T anchor_box_height =
anchors_data[box_offset + 3] - anchors_data[box_offset + 1] + 1;
T anchor_box_center_x = anchors_data[box_offset] + anchor_box_width / 2;
T anchor_box_center_y =
anchors_data[box_offset + 1] + anchor_box_height / 2;
T target_box_center_x = 0, target_box_center_y = 0;
T target_box_width = 0, target_box_height = 0;
target_box_center_x =
bboxes_data[box_offset] * anchor_box_width + anchor_box_center_x;
target_box_center_y =
bboxes_data[box_offset + 1] * anchor_box_height + anchor_box_center_y;
target_box_width = std::exp(bboxes_data[box_offset + 2]) * anchor_box_width;
target_box_height =
std::exp(bboxes_data[box_offset + 3]) * anchor_box_height;
T pred_box_xmin = target_box_center_x - target_box_width / 2;
T pred_box_ymin = target_box_center_y - target_box_height / 2;
T pred_box_xmax = target_box_center_x + target_box_width / 2 - 1;
T pred_box_ymax = target_box_center_y + target_box_height / 2 - 1;
pred_box_xmin = pred_box_xmin / im_scale;
pred_box_ymin = pred_box_ymin / im_scale;
pred_box_xmax = pred_box_xmax / im_scale;
pred_box_ymax = pred_box_ymax / im_scale;
pred_box_xmin = std::max(std::min(pred_box_xmin, im_width - 1), zero);
pred_box_ymin = std::max(std::min(pred_box_ymin, im_height - 1), zero);
pred_box_xmax = std::max(std::min(pred_box_xmax, im_width - 1), zero);
pred_box_ymax = std::max(std::min(pred_box_ymax, im_height - 1), zero);
std::vector<T> one_pred;
one_pred.push_back(pred_box_xmin);
one_pred.push_back(pred_box_ymin);
one_pred.push_back(pred_box_xmax);
one_pred.push_back(pred_box_ymax);
one_pred.push_back(score);
(*preds)[c].push_back(one_pred);
i++;
}
}
template <class T>
void MultiClassNMS(const std::map<int, std::vector<std::vector<T>>>& preds,
int class_num,
const int keep_top_k,
const T nms_threshold,
const T nms_eta,
std::vector<std::vector<T>>* nmsed_out,
int* num_nmsed_out) {
std::map<int, std::vector<int>> indices;
int num_det = 0;
for (int c = 0; c < class_num; ++c) {
if (static_cast<bool>(preds.count(c))) {
const std::vector<std::vector<T>> cls_dets = preds.at(c);
NMSFast(cls_dets, nms_threshold, nms_eta, &(indices[c]));
num_det += indices[c].size();
}
}
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (const auto& it : indices) {
int label = it.first;
const std::vector<int>& label_indices = it.second;
for (size_t j = 0; j < label_indices.size(); ++j) {
int idx = label_indices[j];
score_index_pairs.push_back(
std::make_pair(preds.at(label)[idx][4], std::make_pair(label, idx)));
}
}
// Keep top k results per image.
std::stable_sort(score_index_pairs.begin(),
score_index_pairs.end(),
SortScoreTwoPairDescend<int>);
if (num_det > keep_top_k) {
score_index_pairs.resize(keep_top_k);
}
// Store the new indices.
std::map<int, std::vector<int>> new_indices;
for (const auto& it : score_index_pairs) {
int label = it.second.first;
int idx = it.second.second;
std::vector<T> one_pred;
one_pred.push_back(label);
one_pred.push_back(preds.at(label)[idx][4]);
one_pred.push_back(preds.at(label)[idx][0]);
one_pred.push_back(preds.at(label)[idx][1]);
one_pred.push_back(preds.at(label)[idx][2]);
one_pred.push_back(preds.at(label)[idx][3]);
nmsed_out->push_back(one_pred);
}
*num_nmsed_out = (num_det > keep_top_k ? keep_top_k : num_det);
}
template <class T>
void RetinanetDetectionOutput(
const operators::RetinanetDetectionOutputParam& param,
const std::vector<Tensor>& scores,
const std::vector<Tensor>& bboxes,
const std::vector<Tensor>& anchors,
const Tensor& im_info,
std::vector<std::vector<T>>* nmsed_out,
int* num_nmsed_out) {
int64_t nms_top_k = param.nms_top_k;
int64_t keep_top_k = param.keep_top_k;
T nms_threshold = static_cast<T>(param.nms_threshold);
T nms_eta = static_cast<T>(param.nms_eta);
T score_threshold = static_cast<T>(param.score_threshold);
int64_t class_num = scores[0].dims()[1];
std::map<int, std::vector<std::vector<T>>> preds;
for (size_t l = 0; l < scores.size(); ++l) {
// Fetch per level score
Tensor scores_per_level = scores[l];
// Fetch per level bbox
Tensor bboxes_per_level = bboxes[l];
// Fetch per level anchor
Tensor anchors_per_level = anchors[l];
int64_t scores_num = scores_per_level.numel();
int64_t bboxes_num = bboxes_per_level.numel();
std::vector<T> scores_data(scores_num);
std::vector<T> bboxes_data(bboxes_num);
std::vector<T> anchors_data(bboxes_num);
std::copy_n(scores_per_level.data<T>(), scores_num, scores_data.begin());
std::copy_n(bboxes_per_level.data<T>(), bboxes_num, bboxes_data.begin());
std::copy_n(anchors_per_level.data<T>(), bboxes_num, anchors_data.begin());
std::vector<std::pair<T, int>> sorted_indices;
// For the highest level, we take the threshold 0.0
T threshold = (l < (scores.size() - 1) ? score_threshold : 0.0);
GetMaxScoreIndex(scores_data, threshold, nms_top_k, &sorted_indices);
auto* im_info_data = im_info.data<T>();
auto im_height = im_info_data[0];
auto im_width = im_info_data[1];
auto im_scale = im_info_data[2];
DeltaScoreToPrediction(bboxes_data,
anchors_data,
im_height,
im_width,
im_scale,
class_num,
sorted_indices,
&preds);
}
MultiClassNMS(preds,
class_num,
keep_top_k,
nms_threshold,
nms_eta,
nmsed_out,
num_nmsed_out);
}
template <class T>
void MultiClassOutput(const std::vector<std::vector<T>>& nmsed_out,
Tensor* outs) {
auto* odata = outs->mutable_data<T>();
int count = 0;
int64_t out_dim = 6;
for (size_t i = 0; i < nmsed_out.size(); ++i) {
odata[count * out_dim] = nmsed_out[i][0] + 1; // label
odata[count * out_dim + 1] = nmsed_out[i][1]; // score
odata[count * out_dim + 2] = nmsed_out[i][2]; // xmin
odata[count * out_dim + 3] = nmsed_out[i][3]; // ymin
odata[count * out_dim + 4] = nmsed_out[i][4]; // xmax
odata[count * out_dim + 5] = nmsed_out[i][5]; // ymax
count++;
}
}
void RetinanetDetectionOutputCompute::Run() {
auto& param = Param<operators::RetinanetDetectionOutputParam>();
auto& boxes = param.bboxes;
auto& scores = param.scores;
auto& anchors = param.anchors;
auto* im_info = param.im_info;
auto* outs = param.out;
std::vector<Tensor> boxes_list(boxes.size());
std::vector<Tensor> scores_list(scores.size());
std::vector<Tensor> anchors_list(anchors.size());
for (size_t j = 0; j < boxes_list.size(); ++j) {
boxes_list[j] = *boxes[j];
scores_list[j] = *scores[j];
anchors_list[j] = *anchors[j];
}
auto score_dims = scores_list[0].dims();
int64_t batch_size = score_dims[0];
auto box_dims = boxes_list[0].dims();
int64_t box_dim = box_dims[2];
int64_t out_dim = box_dim + 2;
std::vector<std::vector<std::vector<float>>> all_nmsed_out;
std::vector<uint64_t> batch_starts = {0};
for (int i = 0; i < batch_size; ++i) {
int num_nmsed_out = 0;
std::vector<Tensor> box_per_batch_list(boxes_list.size());
std::vector<Tensor> score_per_batch_list(scores_list.size());
for (size_t j = 0; j < boxes_list.size(); ++j) {
auto score_dims = scores_list[j].dims();
score_per_batch_list[j] = scores_list[j].Slice<float>(i, i + 1);
score_per_batch_list[j].Resize({score_dims[1], score_dims[2]});
box_per_batch_list[j] = boxes_list[j].Slice<float>(i, i + 1);
box_per_batch_list[j].Resize({score_dims[1], box_dim});
}
Tensor im_info_slice = im_info->Slice<float>(i, i + 1);
std::vector<std::vector<float>> nmsed_out;
RetinanetDetectionOutput(param,
score_per_batch_list,
box_per_batch_list,
anchors_list,
im_info_slice,
&nmsed_out,
&num_nmsed_out);
all_nmsed_out.push_back(nmsed_out);
batch_starts.push_back(batch_starts.back() + num_nmsed_out);
}
uint64_t num_kept = batch_starts.back();
if (num_kept == 0) {
outs->Resize({0, out_dim});
} else {
outs->Resize({static_cast<int64_t>(num_kept), out_dim});
for (int i = 0; i < batch_size; ++i) {
int64_t s = static_cast<int64_t>(batch_starts[i]);
int64_t e = static_cast<int64_t>(batch_starts[i + 1]);
if (e > s) {
Tensor out = outs->Slice<float>(s, e);
MultiClassOutput(all_nmsed_out[i], &out);
}
}
}
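// Record the per-image offsets into the flat output as level-0 LoD so that
// callers can recover which rows belong to which input image.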
LoD lod;
lod.emplace_back(batch_starts);
outs->set_lod(lod);
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
retinanet_detection_output,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::RetinanetDetectionOutputCompute,
def)
.BindInput("BBoxes",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("Scores",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("Anchors",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("ImInfo",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class RetinanetDetectionOutputCompute
: public KernelLite<TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
public:
void Run() override;
virtual ~RetinanetDetectionOutputCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -85,6 +85,9 @@ void ConvImageCompute::PrepareForRun() {
<< paddings[2] << " " << paddings[3];
CHECK(pad_equal && stride_equal && dilation_equal);
if (!is_mali) {
use_turn_ = false;
}
// general gws..
auto out_image_shape = InitImageDimInfoWith(output_dims);
......
......@@ -11,9 +11,9 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <random>
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
......
......@@ -3,6 +3,7 @@ if (NOT LITE_ON_TINY_PUBLISH)
endif()
add_subdirectory(cpp)
add_subdirectory(naive_buffer)
add_subdirectory(flatbuffers)
#lite_cc_library(runtime_lite SRCS runtime.cc)
......
function(lite_fbs_library TARGET)
set(multiValueArgs SRCS FBS_DEPS)
cmake_parse_arguments(args "" "" "${multiValueArgs}" ${ARGN})
lite_cc_library(${TARGET} SRCS ${args_SRCS})
add_dependencies(${TARGET} ${args_FBS_DEPS})
endfunction()
lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_program_desc SRCS program_desc.cc FBS_DEPS framework_fbs_header)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/block_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
template <>
proto::VarDesc* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return const_cast<proto::VarDesc*>(desc_->vars()->Get(idx));
}
template <>
proto::OpDesc* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return const_cast<proto::OpDesc*>(desc_->ops()->Get(idx));
}
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/model_parser/base/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace fbs {
class BlockDesc : public BlockDescReadAPI {
public:
explicit BlockDesc(proto::BlockDesc* desc) : desc_(desc) { CHECK(desc_); }
int32_t Idx() const override { return desc_->idx(); }
int32_t ParentIdx() const override { return desc_->parent_idx(); }
size_t VarsSize() const override { return desc_->vars()->size(); }
template <typename T>
T* GetVar(int32_t idx);
template <typename T>
T const* GetVar(int32_t idx) const {
return const_cast<BlockDesc*>(this)->GetVar<T>(idx);
}
size_t OpsSize() const override {
CHECK(desc_);
CHECK(desc_->ops());
return desc_->ops()->size();
}
template <typename T>
T* GetOp(int32_t idx);
template <typename T>
T const* GetOp(int32_t idx) const {
return const_cast<BlockDesc*>(this)->GetOp<T>(idx);
}
int32_t ForwardBlockIdx() const override {
return desc_->forward_block_idx();
}
BlockDesc() = delete;
private:
proto::BlockDesc* desc_; // not_own
};
} // namespace fbs
} // namespace lite
} // namespace paddle
// Generated from framework.proto
namespace paddle.lite.fbs.proto;
enum AttrType : int {
INT = 0,
FLOAT = 1,
STRING = 2,
INTS = 3,
FLOATS = 4,
STRINGS = 5,
BOOLEAN = 6,
BOOLEANS = 7,
BLOCK = 8,
LONG = 9,
BLOCKS = 10,
LONGS = 11,
}
namespace paddle.lite.fbs.proto.VarType_;
enum Type : int {
BOOL = 0,
INT16 = 1,
INT32 = 2,
INT64 = 3,
FP16 = 4,
FP32 = 5,
FP64 = 6,
LOD_TENSOR = 7,
SELECTED_ROWS = 8,
FEED_MINIBATCH = 9,
FETCH_LIST = 10,
STEP_SCOPES = 11,
LOD_RANK_TABLE = 12,
LOD_TENSOR_ARRAY = 13,
PLACE_LIST = 14,
READER = 15,
RAW = 17,
TUPLE = 18,
SIZE_T = 19,
UINT8 = 20,
INT8 = 21,
}
namespace paddle.lite.fbs.proto.CompatibleInfo_;
enum Type : int {
COMPATIBLE = 0,
DEFINITELY_NOT = 1,
POSSIBLE = 2,
BUG_FIX = 3,
PRECISION_CHANGE = 4,
}
namespace paddle.lite.fbs.proto;
table Version {
version:long;
}
table OpDesc {
type:string (required);
inputs:[paddle.lite.fbs.proto.OpDesc_.Var];
outputs:[paddle.lite.fbs.proto.OpDesc_.Var];
attrs:[paddle.lite.fbs.proto.OpDesc_.Attr];
is_target:bool;
}
namespace paddle.lite.fbs.proto.OpDesc_;
table Attr {
name:string (required, key);
type:paddle.lite.fbs.proto.AttrType;
i:int;
f:float;
s:string;
ints:[int];
floats:[float];
strings:[string];
b:bool;
bools:[bool];
block_idx:int;
l:long;
blocks_idx:[int];
longs:[long];
}
table Var {
parameter:string (required, key);
arguments:[string];
}
namespace paddle.lite.fbs.proto;
table VarType {
type:paddle.lite.fbs.proto.VarType_.Type;
selected_rows:paddle.lite.fbs.proto.VarType_.TensorDesc;
lod_tensor:paddle.lite.fbs.proto.VarType_.LoDTensorDesc;
tensor_array:paddle.lite.fbs.proto.VarType_.LoDTensorArrayDesc;
reader:paddle.lite.fbs.proto.VarType_.ReaderDesc;
tuple:paddle.lite.fbs.proto.VarType_.Tuple;
}
namespace paddle.lite.fbs.proto.VarType_;
table TensorDesc {
data_type:paddle.lite.fbs.proto.VarType_.Type;
dims:[long];
}
table LoDTensorDesc {
tensor:paddle.lite.fbs.proto.VarType_.TensorDesc (required);
lod_level:int;
}
table LoDTensorArrayDesc {
tensor:paddle.lite.fbs.proto.VarType_.TensorDesc (required);
lod_level:int;
}
table ReaderDesc {
lod_tensor:[paddle.lite.fbs.proto.VarType_.LoDTensorDesc];
}
table Tuple {
element_type:[paddle.lite.fbs.proto.VarType_.Type];
}
namespace paddle.lite.fbs.proto;
table VarDesc {
name:string (required, key);
type:paddle.lite.fbs.proto.VarType (required);
persistable:bool;
need_check_feed:bool;
}
table BlockDesc {
idx:int;
parent_idx:int;
vars:[paddle.lite.fbs.proto.VarDesc];
ops:[paddle.lite.fbs.proto.OpDesc];
forward_block_idx:int = -1;
}
table CompatibleInfo {
version:string (required);
type:paddle.lite.fbs.proto.CompatibleInfo_.Type;
}
table OpCompatibleMap {
pair:[paddle.lite.fbs.proto.OpCompatibleMap_.OpCompatiblePair];
default_required_version:string;
}
namespace paddle.lite.fbs.proto.OpCompatibleMap_;
table OpCompatiblePair {
op_name:string (required, key);
compatible_info:paddle.lite.fbs.proto.CompatibleInfo (required);
}
namespace paddle.lite.fbs.proto;
table ProgramDesc {
blocks:[paddle.lite.fbs.proto.BlockDesc];
version:paddle.lite.fbs.proto.Version;
op_compatible_map:paddle.lite.fbs.proto.OpCompatibleMap;
}
root_type paddle.lite.fbs.proto.ProgramDesc;
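Since the schema declares a `root_type`, flatc generates typed root accessors for it. A minimal sketch of reading a serialized `ProgramDesc` through the generated header, assuming `buf` points at a finished FlatBuffer:
```cpp
#include "lite/model_parser/flatbuffers/framework_generated.h"

// A sketch: `buf` is assumed to point at a buffer produced by a
// flatbuffers::FlatBufferBuilder holding a ProgramDesc.
void InspectProgram(const void* buf) {
  const auto* program =
      flatbuffers::GetRoot<paddle::lite::fbs::proto::ProgramDesc>(buf);
  if (program->version()) {
    // Optional table fields return nullptr when absent.
    int64_t version = program->version()->version();
    (void)version;
  }
  if (program->blocks()) {
    for (const auto* block : *program->blocks()) {
      // Each BlockDesc exposes idx()/vars()/ops() accessors.
      (void)block->idx();
    }
  }
}
```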
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/op_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
template <>
std::string OpDesc::GetAttr<std::string>(const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str());
CHECK(it) << "Attr " << name << " does not exist.";
if (!it->s()) {
return std::string();
}
return it->s()->str();
}
template <>
std::string OpDesc::GetAttr<std::string>(size_t idx) const {
const auto& it = desc_->attrs()->Get(idx);
if (!it->s()) {
return std::string();
}
return it->s()->str();
}
template <>
std::vector<std::string> OpDesc::GetAttr<std::vector<std::string>>(
const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str());
CHECK(it) << "Attr " << name << "does not exist.";
std::vector<std::string> res;
if (it->strings()) {
res.reserve(it->strings()->size());
for (const auto& v : *it->strings()) {
res.push_back(v->str());
}
}
return res;
}
template <>
std::vector<std::string> OpDesc::GetAttr<std::vector<std::string>>(
size_t idx) const {
const auto& it = desc_->attrs()->Get(idx);
CHECK(it) << "Attr " << idx << "does not exist.";
std::vector<std::string> res;
if (it->strings()) {
res.reserve(it->strings()->size());
for (const auto& v : *it->strings()) {
res.push_back(v->str());
}
}
return res;
}
#define GET_ATTR_IMPL(T, fb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
return it->fb_f__(); \
} \
template <> \
T OpDesc::GetAttr<T>(size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
return it->fb_f__(); \
}
#define GET_ATTRS_IMPL(T, fb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
T res; \
res.reserve(it->fb_f__()->size()); \
for (const auto& v : *it->fb_f__()) { \
res.push_back(v); \
} \
return res; \
} \
template <> \
T OpDesc::GetAttr<T>(size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
T res; \
res.reserve(it->fb_f__()->size()); \
for (const auto& v : *it->fb_f__()) { \
res.push_back(v); \
} \
return res; \
}
GET_ATTR_IMPL(int32_t, i);
GET_ATTR_IMPL(int16_t, block_idx);
GET_ATTR_IMPL(float, f);
GET_ATTR_IMPL(bool, b);
GET_ATTR_IMPL(int64_t, l);
GET_ATTRS_IMPL(std::vector<int>, ints);
GET_ATTRS_IMPL(std::vector<float>, floats);
GET_ATTRS_IMPL(std::vector<int64_t>, longs);
} // namespace fbs
} // namespace lite
} // namespace paddle
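The specializations above give `fbs::OpDesc` a typed attribute interface. A usage sketch follows; the attribute names are illustrative assumptions, not part of the API:
```cpp
#include <vector>
#include "lite/model_parser/flatbuffers/op_desc.h"

// Read a few attributes off a flatbuffers-backed op description.
void ReadAttrs(const paddle::lite::fbs::OpDesc& op) {
  if (op.HasAttr("axis")) {
    int32_t axis = op.GetAttr<int32_t>("axis");  // scalar: Attr.i
    (void)axis;
  }
  float scale = op.GetAttr<float>("scale");  // scalar: Attr.f
  std::vector<int> shape =
      op.GetAttr<std::vector<int>>("shape");  // repeated: Attr.ints
  (void)scale;
  (void)shape;
}
```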
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "lite/model_parser/base/op_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace fbs {
class OpDesc : public OpDescReadAPI {
public:
explicit OpDesc(proto::OpDesc* desc) : desc_(desc) { CHECK(desc_); }
std::string Type() const override { return desc_->type()->str(); }
// Get the arguments of parameter called `param`
std::vector<std::string> Input(const std::string& param) const override {
const auto& var = desc_->inputs()->LookupByKey(param.c_str());
std::vector<std::string> args_vec;
if (var->arguments()) {
args_vec.reserve(var->arguments()->size());
for (const auto& in : *var->arguments()) {
args_vec.push_back(in->str());
}
}
return args_vec;
}
std::vector<std::string> InputArgumentNames() const override {
const auto& vars = desc_->inputs();
std::vector<std::string> input_names_vec;
if (vars) {
input_names_vec.reserve(vars->size());
for (const auto& in : *vars) {
input_names_vec.push_back(in->parameter()->str());
}
}
return input_names_vec;
}
std::vector<std::string> Output(const std::string& param) const override {
const auto& var = desc_->outputs()->LookupByKey(param.c_str());
std::vector<std::string> args_vec;
if (var->arguments()) {
args_vec.reserve(var->arguments()->size());
for (const auto& out : *var->arguments()) {
args_vec.push_back(out->str());
}
}
return args_vec;
}
std::vector<std::string> OutputArgumentNames() const override {
const auto& vars = desc_->outputs();
std::vector<std::string> output_names_vec;
if (vars) {
output_names_vec.reserve(vars->size());
for (const auto& out : *vars) {
output_names_vec.push_back(out->parameter()->str());
}
}
return output_names_vec;
}
bool HasAttr(const std::string& name) const override {
return desc_->attrs()->LookupByKey(name.c_str()) != nullptr;
}
size_t AttrsSize() const { return desc_->attrs()->size(); }
std::string AttrName(size_t idx) const {
return desc_->attrs()->Get(idx)->name()->str();
}
OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
const auto& attr = desc_->attrs()->LookupByKey(name.c_str());
CHECK(attr);
return static_cast<OpDescAPI::AttrType>(attr->type());
}
OpDescAPI::AttrType GetAttrType(size_t idx) const {
const auto& attr = desc_->attrs()->Get(idx);
CHECK(attr);
return static_cast<OpDescAPI::AttrType>(attr->type());
}
std::vector<std::string> AttrNames() const override {
const auto& attrs = desc_->attrs();
std::vector<std::string> attr_names_vec;
if (attrs) {
attr_names_vec.reserve(attrs->size());
for (const auto& attr : *attrs) {
attr_names_vec.push_back(attr->name()->str());
}
}
return attr_names_vec;
}
template <typename T>
T GetAttr(const std::string& name) const;
template <typename T>
T GetAttr(size_t idx) const;
OpDesc() = delete;
private:
proto::OpDesc* desc_;
};
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/program_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
template <>
proto::BlockDesc* ProgramDesc::GetBlock<proto::BlockDesc>(int32_t idx) {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return const_cast<proto::BlockDesc*>(desc_->blocks()->Get(idx));
}
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/model_parser/base/program_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace fbs {
class ProgramDesc : public ProgramDescReadAPI {
public:
explicit ProgramDesc(proto::ProgramDesc *desc) : desc_(desc) { CHECK(desc); }
size_t BlocksSize() const override { return desc_->blocks()->size(); }
template <typename T>
T *GetBlock(int32_t idx);
template <typename T>
T const *GetBlock(int32_t idx) const {
return const_cast<ProgramDesc*>(this)->GetBlock<T>(idx);
}
bool HasVersion() const override { return desc_->version() != nullptr; }
int64_t Version() const override {
CHECK(HasVersion());
return desc_->version()->version();
}
private:
proto::ProgramDesc *desc_; // not_own
};
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/var_desc.h"
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "lite/model_parser/base/var_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace fbs {
class VarDesc : public VarDescReadAPI {
public:
explicit VarDesc(proto::VarDesc* desc) : desc_(desc) {}
std::string Name() const override { return desc_->name()->str(); }
VarDescAPI::Type GetType() const override {
return static_cast<VarDescAPI::Type>(desc_->type()->type());
}
bool Persistable() const override { return desc_->persistable(); }
std::vector<int64_t> GetShape() const override {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
const auto& dims = desc_->type()->lod_tensor()->tensor()->dims();
std::vector<int64_t> dims_vec;
dims_vec.reserve(dims->size());
for (const auto& dim : *dims) {
dims_vec.push_back(dim);
}
return dims_vec;
}
VarDesc() = delete;
private:
proto::VarDesc* desc_;
};
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -39,6 +39,7 @@ add_operator(unsqueeze_op_lite basic SRCS unsqueeze_op.cc DEPS ${op_DEPS})
add_operator(stack_op basic SRCS stack_op.cc DEPS ${op_DEPS})
add_operator(cast_op_lite basic SRCS cast_op.cc DEPS ${op_DEPS})
add_operator(affine_channel_op basic SRCS affine_channel_op.cc DEPS ${op_DEPS})
add_operator(affine_grid_op basic SRCS affine_grid_op.cc DEPS ${op_DEPS})
add_operator(range_op basic SRCS range_op.cc DEPS ${op_DEPS})
add_operator(reduce_mean_op basic SRCS reduce_mean_op.cc DEPS ${op_DEPS})
add_operator(relu_op basic SRCS relu_op.cc DEPS ${op_DEPS})
......@@ -110,6 +111,7 @@ add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposal
add_operator(crf_decoding_op_lite extra SRCS crf_decoding_op.cc DEPS ${op_DEPS})
add_operator(ctc_align_op_lite extra SRCS ctc_align_op.cc DEPS ${op_DEPS})
add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${op_DEPS})
add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......@@ -137,6 +139,7 @@ add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS})
add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS})
add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS})
add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
add_operator(retinanet_detection_output_op extra SRCS retinanet_detection_output_op.cc DEPS ${op_DEPS})
add_operator(where_index_op extra SRCS where_index_op.cc DEPS ${op_DEPS})
# for content-dnn specific
add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/affine_grid_op.h"
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool AffineGridOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
const auto x_dims = param_.X->dims();
CHECK_OR_FALSE(x_dims.size() == 3);
CHECK_OR_FALSE(x_dims[1] == 2 && x_dims[2] == 3);
if (param_.output_shape.size() != 0) {
CHECK_OR_FALSE(param_.output_shape.size() == 4);
}
return true;
}
bool AffineGridOpLite::InferShapeImpl() const {
int N = param_.X->dims()[0];
int H, W;
if (param_.output_shape.size() == 0) {
const auto out_shape = param_.OutputShape->dims();
H = out_shape[2];
W = out_shape[3];
} else {
H = param_.output_shape[2];
W = param_.output_shape[3];
}
param_.Out->Resize(std::vector<int64_t>({N, H, W, 2}));
return true;
}
bool AffineGridOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto x = op_desc.Input("Theta").front();
auto output = op_desc.Output("Output").front();
param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.output_shape = op_desc.GetAttr<std::vector<int>>("output_shape");
param_.Out = scope->FindVar(output)->GetMutable<lite::Tensor>();
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(affine_grid, paddle::lite::operators::AffineGridOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class AffineGridOpLite : public OpLite {
public:
AffineGridOpLite() {}
explicit AffineGridOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "affine_grid"; }
private:
mutable AffineGridParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -1166,6 +1166,13 @@ struct AffineChannelParam : ParamBase {
lite::Tensor* Out{};
};
struct AffineGridParam : ParamBase {
const lite::Tensor* X{}; // Theta:shape {?, 2, 3}
std::vector<int> output_shape;
const lite::Tensor* OutputShape;
lite::Tensor* Out{};
};
struct AnchorGeneratorParam : ParamBase {
const lite::Tensor* Input{};
std::vector<float> anchor_sizes{};
......@@ -1568,6 +1575,20 @@ struct PixelShuffleParam : ParamBase {
lite::Tensor* output{nullptr};
int upscale_factor{1};
};
struct RetinanetDetectionOutputParam : ParamBase {
std::vector<Tensor*> bboxes{};
std::vector<Tensor*> scores{};
std::vector<Tensor*> anchors{};
Tensor* im_info{};
Tensor* out{};
float score_threshold{};
int nms_top_k{};
float nms_threshold{};
float nms_eta{};
int keep_top_k{};
};
struct WhereIndexParam : ParamBase {
const lite::Tensor* input{nullptr};
lite::Tensor* output{nullptr};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/pixel_shuffle_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool PixelShuffleOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
CHECK_OR_FALSE(param_.upscale_factor);
const auto x_dims = param_.x->dims();
const auto upscale_factor = param_.upscale_factor;
CHECK_EQ_OR_FALSE(x_dims[1] % (upscale_factor * upscale_factor), 0);
return true;
}
bool PixelShuffleOpLite::InferShapeImpl() const {
const auto x_dims = param_.x->dims();
const auto upscale_factor = param_.upscale_factor;
auto output_dims = x_dims;
output_dims[0] = x_dims[0];
output_dims[1] = x_dims[1] / (upscale_factor * upscale_factor);
output_dims[2] = x_dims[2] * upscale_factor;
output_dims[3] = x_dims[3] * upscale_factor;
param_.output->Resize(output_dims);
return true;
}
bool PixelShuffleOpLite::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
auto input = opdesc.Input("X").front();
auto out = opdesc.Output("Out").front();
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
if (opdesc.HasAttr("upscale_factor")) {
param_.upscale_factor = opdesc.GetAttr<int>("upscale_factor");
}
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(pixel_shuffle, paddle::lite::operators::PixelShuffleOpLite);
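The shape rule enforced by `InferShapeImpl` is the standard pixel_shuffle rearrangement, (N, C·r², H, W) → (N, C, H·r, W·r). A naive reference sketch of the semantics (illustrative, not the Lite kernel):
```cpp
// Naive pixel_shuffle: move each r x r sub-block of channels into space.
// in:  N x (C*r*r) x H x W, out: N x C x (H*r) x (W*r)
void PixelShuffleReference(const float* in, int N, int C, int H, int W, int r,
                           float* out) {
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int oh = 0; oh < H * r; ++oh)
        for (int ow = 0; ow < W * r; ++ow) {
          // Which input channel feeds this output pixel.
          int ic = c * r * r + (oh % r) * r + (ow % r);
          int ih = oh / r, iw = ow / r;
          out[((n * C + c) * H * r + oh) * W * r + ow] =
              in[((n * C * r * r + ic) * H + ih) * W + iw];
        }
}
```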
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class PixelShuffleOpLite : public OpLite {
public:
PixelShuffleOpLite() {}
explicit PixelShuffleOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "pixel_shuffle"; }
private:
mutable PixelShuffleParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/retinanet_detection_output_op.h"
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool RetinanetDetectionOutputOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.bboxes.size() > 0);
CHECK_OR_FALSE(param_.scores.size() > 0);
CHECK_OR_FALSE(param_.anchors.size() > 0);
CHECK_OR_FALSE(param_.bboxes.size() == param_.scores.size());
CHECK_OR_FALSE(param_.bboxes.size() == param_.anchors.size());
CHECK_OR_FALSE(param_.im_info);
CHECK_OR_FALSE(param_.out);
DDim bbox_dims = param_.bboxes.front()->dims();
DDim score_dims = param_.scores.front()->dims();
DDim anchor_dims = param_.anchors.front()->dims();
DDim im_info_dims = param_.im_info->dims();
CHECK_OR_FALSE(bbox_dims.size() == 3);
CHECK_OR_FALSE(score_dims.size() == 3);
CHECK_OR_FALSE(anchor_dims.size() == 2);
CHECK_OR_FALSE(bbox_dims[2] == 4);
CHECK_OR_FALSE(bbox_dims[1] == score_dims[1]);
CHECK_OR_FALSE(anchor_dims[0] == bbox_dims[1]);
CHECK_OR_FALSE(im_info_dims.size() == 2);
return true;
}
bool RetinanetDetectionOutputOpLite::InferShapeImpl() const {
DDim bbox_dims = param_.bboxes.front()->dims();
param_.out->Resize({bbox_dims[1], bbox_dims[2] + 2});
return true;
}
bool RetinanetDetectionOutputOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
for (auto arg_name : op_desc.Input("BBoxes")) {
param_.bboxes.push_back(
scope->FindVar(arg_name)->GetMutable<lite::Tensor>());
}
for (auto arg_name : op_desc.Input("Scores")) {
param_.scores.push_back(
scope->FindVar(arg_name)->GetMutable<lite::Tensor>());
}
for (auto arg_name : op_desc.Input("Anchors")) {
param_.anchors.push_back(
scope->FindVar(arg_name)->GetMutable<lite::Tensor>());
}
AttachInput(op_desc, scope, "ImInfo", false, &param_.im_info);
AttachOutput(op_desc, scope, "Out", false, &param_.out);
param_.score_threshold = op_desc.GetAttr<float>("score_threshold");
param_.nms_top_k = op_desc.GetAttr<int>("nms_top_k");
param_.nms_threshold = op_desc.GetAttr<float>("nms_threshold");
param_.nms_eta = op_desc.GetAttr<float>("nms_eta");
param_.keep_top_k = op_desc.GetAttr<int>("keep_top_k");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(retinanet_detection_output,
paddle::lite::operators::RetinanetDetectionOutputOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class RetinanetDetectionOutputOpLite : public OpLite {
public:
RetinanetDetectionOutputOpLite() {}
explicit RetinanetDetectionOutputOpLite(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "retinanet_detection_output";
}
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {}
#endif
private:
mutable RetinanetDetectionOutputParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Subproject commit ac203b20926b13a35ff85277d2e5d3c38698eee8