Unverified commit c2f72cb3, authored by juncaipeng and committed by GitHub

Support mask_rcnn (#2484)

* add arm split lod tensor, test=develop

* add arm merge lod tensor, test=develop

* update split merge lod tensor, test=develop

* add reduce_prob op, test=develop

* support mask_rcnn succeed, test=develop
Parent commit e17295cc
......@@ -120,5 +120,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
stack.cc
affine_channel.cc
anchor_generator.cc
split_merge_lod_tenosr.cc
reduce_prod.cc
DEPS ${lite_kernel_deps} context tensor)
endif()
......@@ -51,6 +51,7 @@
#include "lite/backends/arm/math/prior_box.h"
#include "lite/backends/arm/math/reduce_max.h"
#include "lite/backends/arm/math/reduce_mean.h"
#include "lite/backends/arm/math/reduce_prod.h"
#include "lite/backends/arm/math/scale.h"
#include "lite/backends/arm/math/sequence_expand.h"
#include "lite/backends/arm/math/sequence_pool.h"
......@@ -61,6 +62,7 @@
#include "lite/backends/arm/math/slice.h"
#include "lite/backends/arm/math/softmax.h"
#include "lite/backends/arm/math/split.h"
#include "lite/backends/arm/math/split_merge_lod_tenosr.h"
#include "lite/backends/arm/math/stack.h"
#include "lite/backends/arm/math/topk.h"
#include "lite/backends/arm/math/yolo_box.h"
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/backends/arm/math/reduce_prod.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void reduce_prod_n(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int hw_size = height_in * width_in;
int chw_size = channel_in * hw_size;
int data_index, src_index, src_index0;
for (int c = 0; c < channel_in; ++c) {
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
data_index = c * hw_size + h * width_in + w;
dst[data_index] = static_cast<T>(1);
for (int n = 0; n < num_in; ++n) {
src_index = n * chw_size + data_index;
dst[data_index] *= src[src_index];
}
}
}
}
}
template <typename T>
void reduce_prod_c(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int hw_size = height_in * width_in;
int chw_size = hw_size * channel_in;
int data_index, src_index0, src_index;
for (int n = 0; n < num_in; ++n) {
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
data_index = n * hw_size + h * width_in + w;
src_index0 = n * chw_size + h * width_in + w;
dst[data_index] = static_cast<T>(1);
for (int c = 0; c < channel_in; ++c) {
src_index = src_index0 + c * hw_size;
dst[data_index] *= src[src_index];
}
}
}
}
}
template <typename T>
void reduce_prod_h(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int cw_size = channel_in * width_in;
int chw_size = cw_size * height_in;
int hw_size = height_in * width_in;
int data_index, src_index, src_index0;
for (int n = 0; n < num_in; ++n) {
for (int c = 0; c < channel_in; ++c) {
for (int w = 0; w < width_in; ++w) {
data_index = n * cw_size + c * width_in + w;
src_index0 = n * chw_size + c * hw_size + w;
dst[data_index] = static_cast<T>(1);
for (int h = 0; h < height_in; ++h) {
src_index = src_index0 + h * width_in;
dst[data_index] *= src[src_index];
}
}
}
}
}
template <typename T>
void reduce_prod_w(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int ch_size = channel_in * height_in;
int hw_size = height_in * width_in;
int chw_size = ch_size * width_in;
int data_index = 0;
int src_index0 = 0;
int src_index = 0;
for (int n = 0; n < num_in; ++n) {
for (int c = 0; c < channel_in; ++c) {
for (int h = 0; h < height_in; ++h) {
data_index = n * ch_size + c * height_in + h;
src_index0 = n * chw_size + c * hw_size + h * width_in;
dst[data_index] = static_cast<T>(1);
for (int w = 0; w < width_in; ++w) {
src_index = src_index0 + w;
dst[data_index] *= src[src_index];
}
}
}
}
}
template <typename T>
void reduce_prod_nc(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce n first.
DDimLite ddimA({1, channel_in, height_in, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
auto* tmp_out = tensor_tmp.mutable_data<T>();
reduce_prod_n(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_prod_c(tmp_out, dst, 1, channel_in, height_in, width_in);
}
template <typename T>
void reduce_prod_ch(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce c first
DDimLite ddimA({num_in, 1, height_in, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
auto* tmp_out = tensor_tmp.mutable_data<T>();
reduce_prod_c(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_prod_h(tmp_out, dst, num_in, 1, height_in, width_in);
}
template <typename T>
void reduce_prod_hw(const T* src,
T* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce h first
DDimLite ddimA({num_in, channel_in, 1, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
auto* tmp_out = tensor_tmp.mutable_data<T>();
reduce_prod_h(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_prod_w(tmp_out, dst, num_in, channel_in, 1, width_in);
}
template <typename T>
void reduce_prod_all(const T* src, T* dst, int64_t total_num) {
dst[0] = static_cast<T>(1);
for (int n = 0; n < total_num; ++n) {
dst[0] *= src[n];
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
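A minimal usage sketch of the new reduce_prod helpers; the buffer shape and values are illustrative assumptions, everything else comes from the header above. reduce_prod_c collapses the channel axis of an NCHW tensor:
// illustrative 1x2x2x2 NCHW buffer (n=1, c=2, h=2, w=2)
float src[8] = {1, 2, 3, 4, 5, 6, 7, 8};
float dst[4];  // result shape 1x1x2x2
paddle::lite::arm::math::reduce_prod_c(src, dst, 1, 2, 2, 2);
// dst == {1*5, 2*6, 3*7, 4*8} == {5, 12, 21, 32}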
......@@ -86,6 +86,13 @@ template void slice(const int* input,
std::vector<int> ends,
int* out,
Context<TARGET(kARM)>* ctx);
template void slice(const float* input,
std::vector<int64_t> dims,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends,
float* out,
Context<TARGET(kARM)>* ctx);
} // namespace math
} // namespace arm
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/split_merge_lod_tenosr.h"
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
using LoDAndOffset = std::pair<LoD, std::pair<size_t, size_t>>;
LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod,
size_t start_idx,
size_t end_idx,
size_t start_level) {
LoD sub_lod;
for (size_t level_idx = start_level; level_idx < lod.size(); ++level_idx) {
CHECK(start_idx <= end_idx);
CHECK(end_idx < lod[level_idx].size());
std::vector<uint64_t> level_lens;
for (size_t i = start_idx; i < end_idx; ++i) {
level_lens.push_back(lod[level_idx][i + 1] - lod[level_idx][i]);
}
sub_lod.emplace_back(level_lens);
start_idx = lod[level_idx][start_idx];
end_idx = lod[level_idx][end_idx];
}
return LoDAndOffset{sub_lod, {start_idx, end_idx}};
}
void AppendLoD(LoD *lod, const LoD &lod_length) {
CHECK(lod->empty() || lod->size() == lod_length.size());
if (lod->empty()) {
for (size_t i = 0; i < lod_length.size(); ++i) {
lod->emplace_back(std::vector<uint64_t>({0}));
}
}
for (size_t i = 0; i < lod->size(); ++i) {
auto &level = (*lod)[i];
for (auto len : lod_length[i]) {
level.push_back(level.back() + len);
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
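As a concrete illustration, feeding GetSubLoDAndAbsoluteOffset the same two-level LoD used by the split_lod_tensor test later in this commit:
LoD lod = {{0, 2, 3, 5}, {0, 1, 3, 6, 8, 9}};
// sub-sequence [1, 2) at level 0
auto res = paddle::lite::arm::math::GetSubLoDAndAbsoluteOffset(lod, 1, 2, 0);
// res.first  == {{1}, {3}}  -- the sub-LoD below level 0
// res.second == {3, 6}      -- absolute row range of that sequence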
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <utility>
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
std::pair<LoD, std::pair<size_t, size_t>> GetSubLoDAndAbsoluteOffset(
const LoD &lod, size_t start_idx, size_t end_idx, size_t start_level);
void AppendLoD(LoD *lod, const LoD &lod_length);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
......@@ -54,6 +54,8 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
CREATE_KERNEL1(target__, kFP16); \
case PRECISION(kAny): \
CREATE_KERNEL1(target__, kAny); \
case PRECISION(kInt32): \
CREATE_KERNEL1(target__, kInt32); \
case PRECISION(kInt64): \
CREATE_KERNEL1(target__, kInt64); \
default: \
......@@ -136,6 +138,7 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kARM, kInt8, kNCHW);
INIT_FOR(kARM, kAny, kNCHW);
INIT_FOR(kARM, kAny, kAny);
INIT_FOR(kARM, kInt32, kNCHW);
INIT_FOR(kOpenCL, kFloat, kNCHW);
INIT_FOR(kOpenCL, kFloat, kNHWC);
......
......@@ -145,6 +145,9 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kInt32),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
......
......@@ -17,6 +17,7 @@
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/cpp/op_desc.h"
#include "lite/model_parser/cpp/var_desc.h"
#include "lite/operators/conditional_block_op.h"
#include "lite/operators/while_op.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/precision_profiler.h"
......@@ -141,12 +142,17 @@ void Program::Build(const cpp::ProgramDesc& prog) {
VLOG(4) << "create Op [" << op_type << "]";
auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type;
if (op_type == "while") {
if (op_type == "while" || op_type == "conditional_block") {
auto sub_block_idx = op_desc.GetAttr<int32_t>("sub_block");
auto sub_block =
const_cast<cpp::ProgramDesc&>(prog).GetBlock<cpp::BlockDesc>(
sub_block_idx);
static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock(sub_block);
if (op_type == "while") {
static_cast<operators::WhileOpLite*>(op.get())->SetSubBlock(sub_block);
} else if (op_type == "conditional_block") {
static_cast<operators::ConditionalBlockOpLite*>(op.get())->SetSubBlock(
sub_block);
}
}
ops_.emplace_back(std::move(op));
ops_.back()->Attach(op_desc, exec_scope_);
......
......@@ -61,11 +61,16 @@ add_kernel(im2sequence_compute_arm ARM extra SRCS im2sequence_compute.cc DEPS ${
add_kernel(sequence_pool_compute_arm ARM extra SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(gather_compute_arm ARM extra SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_prod_compute_arm ARM extra SRCS reduce_prod_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(split_lod_tensor_compute_arm ARM extra SRCS split_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(merge_lod_tensor_compute_arm ARM extra SRCS merge_lod_tensor_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(anchor_generator_compute_arm ARM extra SRCS anchor_generator_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(generate_proposals_compute_arm ARM extra SRCS generate_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(roi_align_compute_arm ARM extra SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(box_clip_compute_arm ARM extra SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(assign_value_compute_arm ARM extra SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(conditional_block_compute_arm ARM extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(collect_fpn_proposals_compute_arm ARM extra SRCS collect_fpn_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
# for OCR specific
......@@ -107,6 +112,8 @@ lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS tran
lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm)
lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_compute_arm)
if(LITE_BUILD_EXTRA)
lite_cc_test(test_split_lod_tensor_compute_arm SRCS split_lod_tensor_compute_test.cc DEPS split_lod_tensor_compute_arm)
lite_cc_test(test_merge_lod_tensor_compute_arm SRCS merge_lod_tensor_compute_test.cc DEPS merge_lod_tensor_compute_arm)
lite_cc_test(test_lrn_compute_arm SRCS lrn_compute_test.cc DEPS lrn_compute_arm)
lite_cc_test(test_decode_bboxes_compute_arm SRCS decode_bboxes_compute_test.cc DEPS decode_bboxes_compute_arm)
lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/collect_fpn_proposals_compute.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
struct ScoreWithID {
float score;
int batch_id;
int index;
int level;
ScoreWithID() {
batch_id = -1;
index = -1;
level = -1;
}
ScoreWithID(float score_, int batch_id_, int index_, int level_) {
score = score_;
batch_id = batch_id_;
index = index_;
level = level_;
}
};
static inline bool CompareByScore(ScoreWithID a, ScoreWithID b) {
// std::stable_sort requires a strict weak ordering, so use '>' rather than '>='
return a.score > b.score;
}
static inline bool CompareByBatchid(ScoreWithID a, ScoreWithID b) {
return a.batch_id < b.batch_id;
}
void CollectFpnProposalsCompute::Run() {
auto& param = Param<operators::CollectFpnProposalsParam>();
auto multi_layer_rois = param.multi_level_rois;
auto multi_layer_scores = param.multi_level_scores;
auto* fpn_rois = param.fpn_rois;
int post_nms_topN = param.post_nms_topN;
if (multi_layer_rois.size() != multi_layer_scores.size()) {
LOG(FATAL) << "multi_layer_rois.size() should be equan to "
"multi_layer_scores.size()";
}
size_t num_fpn_level = multi_layer_rois.size();
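// integral_of_all_rois is a prefix sum over FPN levels: entry i+1 holds the
// total RoI count of levels 0..i, so the last entry is the overall RoI count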
std::vector<int> integral_of_all_rois(num_fpn_level + 1, 0);
for (size_t i = 0; i < num_fpn_level; ++i) {
auto cur_rois_lod = multi_layer_rois[i]->lod().back();
integral_of_all_rois[i + 1] = static_cast<int>(
integral_of_all_rois[i] + cur_rois_lod[cur_rois_lod.size() - 1]);
}
std::vector<ScoreWithID> scores_of_all_rois(
integral_of_all_rois[num_fpn_level], ScoreWithID());
for (int i = 0; i < num_fpn_level; ++i) {
const float* cur_level_scores = multi_layer_scores[i]->data<float>();
int cur_level_num = integral_of_all_rois[i + 1] - integral_of_all_rois[i];
auto cur_scores_lod = multi_layer_scores[i]->lod().back();
int cur_batch_id = 0;
for (int j = 0; j < cur_level_num; ++j) {
if (j >= cur_scores_lod[cur_batch_id + 1]) {
cur_batch_id++;
}
int cur_index = j + integral_of_all_rois[i];
scores_of_all_rois[cur_index].score = cur_level_scores[j];
scores_of_all_rois[cur_index].index = j;
scores_of_all_rois[cur_index].level = i;
scores_of_all_rois[cur_index].batch_id = cur_batch_id;
}
}
// keep top post_nms_topN rois, sort the rois by the score
if (post_nms_topN > integral_of_all_rois[num_fpn_level]) {
post_nms_topN = integral_of_all_rois[num_fpn_level];
}
std::stable_sort(
scores_of_all_rois.begin(), scores_of_all_rois.end(), CompareByScore);
scores_of_all_rois.resize(post_nms_topN);
// sort by batch id
std::stable_sort(
scores_of_all_rois.begin(), scores_of_all_rois.end(), CompareByBatchid);
// create a pointer array
std::vector<const float*> multi_fpn_rois_data(num_fpn_level);
for (int i = 0; i < num_fpn_level; ++i) {
multi_fpn_rois_data[i] = multi_layer_rois[i]->data<float>();
}
// initialize the outputs
const int kBoxDim = 4;
auto fpn_rois_data = fpn_rois->mutable_data<float>();
std::vector<uint64_t> lod0(1, 0);
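// lod0 records, per image in the batch, the offset of its first RoI in FpnRois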
int cur_batch_id = 0;
for (int i = 0; i < post_nms_topN; ++i) {
int cur_fpn_level = scores_of_all_rois[i].level;
int cur_level_index = scores_of_all_rois[i].index;
std::memcpy(fpn_rois_data,
multi_fpn_rois_data[cur_fpn_level] + cur_level_index * kBoxDim,
kBoxDim * sizeof(float));
fpn_rois_data += kBoxDim;
if (scores_of_all_rois[i].batch_id != cur_batch_id) {
cur_batch_id = scores_of_all_rois[i].batch_id;
lod0.emplace_back(i);
}
}
lod0.emplace_back(post_nms_topN);
lite::LoD lod;
lod.emplace_back(lod0);
fpn_rois->set_lod(lod);
return;
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(collect_fpn_proposals,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::CollectFpnProposalsCompute,
def)
.BindInput("MultiLevelRois", {LiteType::GetTensorListTy(TARGET(kARM))})
.BindInput("MultiLevelScores", {LiteType::GetTensorListTy(TARGET(kARM))})
.BindOutput("FpnRois", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/axpy_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class CollectFpnProposalsCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::CollectFpnProposalsParam;
void Run() override;
virtual ~CollectFpnProposalsCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -112,6 +112,42 @@ void CompareCompute<Functor>::Run() {
}
}
template <template <typename T> class Functor>
void CompareCompute_int32<Functor>::Run() {
auto &param = this->Param<operators::CompareParam>();
using CompareFunctor = Functor<int>;
const size_t x_size = param.X->numel();
const size_t y_size = param.Y->numel();
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
bool *z = param.Out->template mutable_data<bool>();
const auto *x = param.X->template data<int>();
const auto *y = param.Y->template data<int>();
auto axis = param.axis;
bool force_cpu = param.force_cpu;
if (x_size == y_size) {
for (int i = 0; i < x_size; ++i) {
z[i] = CompareFunctor()(x[i], y[i]);
}
} else {
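// broadcast path: Y is expanded along the axis inferred from the rank
// difference, mirroring the float CompareCompute above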
int axis = (param.axis == -1 ? x_dims.size() - y_dims.size() : param.axis);
int outer_num, mid_num, inner_num;
get_mid_dims(x_dims, y_dims, axis, &outer_num, &mid_num, &inner_num);
for (int outer_id = 0; outer_id < outer_num; ++outer_id) {
for (int mid_id = 0; mid_id < mid_num; ++mid_id) {
auto y_data = y[mid_id];
for (int inner_id = 0; inner_id < inner_num; ++inner_id) {
int index = (outer_id * mid_num + mid_id) * inner_num + inner_id;
z[index] = CompareFunctor()(x[index], y_data);
// z[index] = x[index] < y_data;
}
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -183,3 +219,14 @@ REGISTER_LITE_KERNEL(greater_equal,
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
REGISTER_LITE_KERNEL(equal,
kARM,
kInt32,
kNCHW,
paddle::lite::kernels::arm::CompareCompute_int32<
paddle::lite::kernels::arm::_EqualFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
......@@ -33,8 +33,17 @@ class CompareCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
void Run() override;
~CompareCompute() {}
};
template <template <typename T> class Functor>
class CompareCompute_int32
: public KernelLite<TARGET(kARM), PRECISION(kInt32)> {
public:
using param_t = operators::LogicalParam;
void Run() override;
~CompareCompute_int32() {}
private:
};
} // namespace arm
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/conditional_block_compute.h"
#include <memory>
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void ConditionalBlockCompute::PrepareForRun() {
auto& param = Param<operators::ConditionalBlockParam>();
auto cur_scope = param.scope;
executor_ =
std::make_shared<CondExecutor>(param.sub_block, cur_scope, place());
}
void ConditionalBlockCompute::Run() {
auto& param = Param<operators::ConditionalBlockParam>();
bool need_run = true;
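// scalar mode: the single bool in Cond decides whether the sub-block runs;
// otherwise run only if every input tensor is initialized and non-empty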
if (param.is_scalar_condition) {
auto* cond = param.cond;
auto* cond_data = cond->data<bool>();
need_run = cond_data[0];
} else {
auto x = param.x;
for (auto pt : x) {
if (pt == nullptr || !pt->IsInitialized() || pt->dims().empty()) {
need_run = false;
break;
}
}
}
if (need_run) {
executor_->Run();
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(conditional_block,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ConditionalBlockCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Cond", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Scope", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/program.h"
#include "lite/operators/conditional_block_op.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/basic_profiler.h"
#endif // LITE_WITH_PROFILE
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/precision_profiler.h"
#endif
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class CondExecutor {
typedef std::shared_ptr<OpLite> OpPtr;
public:
CondExecutor(cpp::BlockDesc *block, Scope *scope, Place place)
: scope_(scope), place_(place) {
int32_t op_size = block->OpsSize();
for (int32_t i = 0; i < op_size; ++i) {
auto &op_desc = *block->template GetOp<cpp::OpDesc>(i);
auto op_type = op_desc.Type();
auto op_handler = lite::LiteOpRegistry::Global().Create(op_desc.Type());
op_handler->Attach(op_desc, scope);
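// also offer a kHost place so sub-block ops that only provide host kernels
// can still pick a kernel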
auto hostplace = place_;
hostplace.target = TARGET(kHost);
auto kernels = op_handler->CreateKernels({place_, hostplace});
CHECK_GT(kernels.size(), 0) << "cannot create kernel";
op_handler->AttachKernel(kernels[0].get());
op_handler->SetKernel(kernels);
ops_of_block_.push_back(op_handler);
}
}
void Run() {
for (auto &op_handler : ops_of_block_) {
op_handler->CheckShape();
op_handler->InferShape();
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
std::unique_ptr<KernelBase> kernel(op_handler->GetKernel());
Instruction inst(op_handler, std::move(kernel));
#endif // LITE_WITH_PRECISION_PROFILE
#endif // LITE_WITH_PROFILE
op_handler->Run();
#ifdef LITE_WITH_PROFILE
#ifdef LITE_WITH_PRECISION_PROFILE
LITE_PRECISION_PROFILE(inst)
#endif // LITE_WITH_PRECISION_PROFILE
#endif // LITE_WITH_PROFILE
}
}
private:
Scope *scope_;
Place place_;
std::vector<OpPtr> ops_of_block_;
};
class ConditionalBlockCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ConditionalBlockParam;
void PrepareForRun() override;
void Run() override;
virtual ~ConditionalBlockCompute() = default;
private:
std::shared_ptr<CondExecutor> executor_;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -20,8 +20,7 @@ namespace lite {
namespace kernels {
namespace arm {
template <typename T>
class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public:
using param_t = operators::FillConstantParam;
......@@ -86,9 +85,8 @@ class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual ~FillConstantCompute() = default;
};
template <typename T>
class FillConstantBatchLikeCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
: public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public:
using param_t = operators::FillConstantBatchLikeParam;
......@@ -135,24 +133,23 @@ class FillConstantBatchLikeCompute
// float
REGISTER_LITE_KERNEL(fill_constant,
kARM,
kFloat,
kAny,
kNCHW,
paddle::lite::kernels::arm::FillConstantCompute<float>,
paddle::lite::kernels::arm::FillConstantCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("ShapeTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(
fill_constant_batch_size_like,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::FillConstantBatchLikeCompute<float>,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
REGISTER_LITE_KERNEL(fill_constant_batch_size_like,
kARM,
kAny,
kNCHW,
paddle::lite::kernels::arm::FillConstantBatchLikeCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/merge_lod_tensor_compute.h"
#include <string>
#include <utility>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
struct CopyRange {
size_t begin;
size_t end;
};
void MergeLodTensorCompute::Run() {
auto &param = Param<operators::MergeLodTensorParam>();
const lite::Tensor *x = param.x;
const lite::Tensor *mask = param.mask;
const lite::Tensor *in_true = param.in_true;
const lite::Tensor *in_false = param.in_false;
lite::Tensor *out = param.out;
int level = param.level;
CHECK(in_true->IsInitialized() || in_false->IsInitialized());
auto &in_true_dim = in_true->dims();
auto &in_false_dim = in_false->dims();
// only merge the first dim
int64_t batch_size = 0;
std::vector<int64_t> out_shape;
if (in_true->IsInitialized()) {
batch_size += in_true->dims()[0];
}
if (in_false->IsInitialized()) {
batch_size += in_false->dims()[0];
}
out_shape.push_back(batch_size);
if (in_true->IsInitialized()) {
for (int i = 1; i < in_true_dim.size(); i++) {
out_shape.push_back(in_true_dim[i]);
}
} else {
for (int i = 1; i < in_false_dim.size(); i++) {
out_shape.push_back(in_false_dim[i]);
}
}
out->Resize(out_shape);
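// base_num: elements per row, i.e. the product of all dims except the merged batch dim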
size_t base_num = static_cast<size_t>(out->numel() / batch_size);
auto *out_data = out->mutable_data<float>();
auto *out_lod = out->mutable_lod();
out_lod->clear();
auto &mask_dim = mask->dims();
auto *mask_data = mask->data<bool>();
size_t out_offset = 0;
size_t in_true_idx = 0;
size_t in_false_idx = 0;
for (size_t i = 0; i < static_cast<size_t>(mask_dim[0]); i++) {
const Tensor *input = nullptr;
size_t *in_idx = nullptr;
if (static_cast<int>(mask_data[i]) == 0) {
input = in_false;
in_idx = &in_false_idx;
} else {
input = in_true;
in_idx = &in_true_idx;
}
auto lod_and_offset = lite::arm::math::GetSubLoDAndAbsoluteOffset(
input->lod(), *in_idx, (*in_idx) + 1, 0);
auto &lod_length = lod_and_offset.first;
lite::arm::math::AppendLoD(out_lod, lod_length);
size_t start_offset = lod_and_offset.second.first;
size_t end_offset = lod_and_offset.second.second;
CHECK(end_offset >= start_offset);
size_t len = end_offset - start_offset;
if (len == 0) {
continue;
}
auto *in_src = input->data<float>() + base_num * start_offset;
auto *out_dest = out_data + base_num * out_offset;
size_t copy_num = base_num * len * sizeof(float);
memcpy(out_dest, in_src, copy_num);
out_offset += len;
(*in_idx) += 1;
}
for (size_t i = 0; i < level; i++) {
out_lod->insert(out_lod->begin(), x->lod()[i]);
}
return;
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(merge_lod_tensor,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::MergeLodTensorCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Mask", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.BindInput("InTrue", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("InFalse", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/merge_lod_tensor_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class MergeLodTensorCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::MergeLodTensorParam;
void Run() override;
virtual ~MergeLodTensorCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/merge_lod_tensor_compute.h"
#include <gtest/gtest.h>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
TEST(merge_lod_tensor_arm, retrive_op) {
auto kernel =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"merge_lod_tensor");
ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front());
}
TEST(merge_lod_tensor_arm, init) {
MergeLodTensorCompute cpt;
ASSERT_EQ(cpt.precision(), PRECISION(kFloat));
ASSERT_EQ(cpt.target(), TARGET(kARM));
}
TEST(merge_lod_tensor_arm_0, compute) {
DeviceInfo::Init();
Tensor x;
Tensor mask;
Tensor in_true;
Tensor in_false;
Tensor out;
int level = 0;
// set dims and lod
mask.Resize({3, 1});
in_true.Resize({1, 1});
LoD in_true_lod;
std::vector<uint64_t> in_true_lod0 = {0, 1};
in_true_lod.push_back(in_true_lod0);
in_true.set_lod(in_true_lod);
in_false.Resize({4, 1});
LoD in_false_lod;
std::vector<uint64_t> in_false_lod0 = {0, 2, 4};
in_false_lod.push_back(in_false_lod0);
in_false.set_lod(in_false_lod);
// initialize data
auto* in_true_data = in_true.mutable_data<float>();
for (size_t i = 0; i < in_true.numel(); i++) {
in_true_data[i] = static_cast<float>(i);
}
auto* in_false_data = in_false.mutable_data<float>();
for (size_t i = 0; i < in_false.numel(); i++) {
in_false_data[i] = static_cast<float>(i + 1);
}
auto* mask_data = mask.mutable_data<bool>();
for (size_t i = 0; i < mask.numel(); i++) {
mask_data[i] = static_cast<bool>(i % 2);
}
// prepare kernel params and run to obtain output_data
MergeLodTensorCompute op;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
op.SetContext(std::move(ctx));
operators::MergeLodTensorParam param;
param.x = &x;
param.mask = &mask;
param.in_true = &in_true;
param.in_false = &in_false;
param.out = &out;
param.level = level;
op.SetParam(param);
op.Launch();
auto* out_data = out.data<float>();
std::vector<float> out_ref = {1, 2, 0, 3, 4};
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_data[i], out_ref[i], 1e-5);
}
}
TEST(merge_lod_tensor_arm_1, compute) {
DeviceInfo::Init();
Tensor x;
Tensor mask;
Tensor in_true;
Tensor in_false;
Tensor out;
int level = 0;
// set dims and lod
mask.Resize({3, 1});
in_true.Resize({3, 3});
LoD in_true_lod = {{0, 1}, {0, 3}};
in_true.set_lod(in_true_lod);
in_false.Resize({6, 3});
LoD in_false_lod = {{0, 2, 4}, {0, 1, 3, 5, 6}};
in_false.set_lod(in_false_lod);
// initialize data
auto* in_true_data = in_true.mutable_data<float>();
for (size_t i = 0; i < in_true.numel(); i++) {
in_true_data[i] = static_cast<float>(i);
}
auto* in_false_data = in_false.mutable_data<float>();
for (size_t i = 0; i < in_false.numel(); i++) {
in_false_data[i] = static_cast<float>(i + 1);
}
auto* mask_data = mask.mutable_data<bool>();
for (size_t i = 0; i < mask.numel(); i++) {
mask_data[i] = static_cast<bool>(i % 2);
}
// prepare kernel params and run to obtain output_data
MergeLodTensorCompute op;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
op.SetContext(std::move(ctx));
operators::MergeLodTensorParam param;
param.x = &x;
param.mask = &mask;
param.in_true = &in_true;
param.in_false = &in_false;
param.out = &out;
param.level = level;
op.SetParam(param);
op.Launch();
auto* out_data = out.data<float>();
std::vector<float> out_ref = {1, 2, 3, 4, 5, 6, 7, 8, 9,
0, 1, 2, 3, 4, 5, 6, 7, 8,
10, 11, 12, 13, 14, 15, 16, 17, 18};
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_data[i], out_ref[i], 1e-5);
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(merge_lod_tensor, kARM, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/reduce_prod_compute.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename T, PrecisionType Ptype>
void ReduceProdCompute<T, Ptype>::Run() {
auto& param = this->template Param<operators::ReduceParam>();
auto* input = param.x->template data<T>();
auto x_dims = param.x->dims();
int x_rank = x_dims.size();
auto* output = param.output->template mutable_data<T>();
std::vector<int> dim = param.dim;
bool keep_dim = param.keep_dim;
bool reduce_all = param.reduce_all;
if (!dim.empty()) {
for (int i = 0; i < dim.size(); i++) {
if (dim[i] < 0) {
dim[i] += x_rank;
}
}
}
if (reduce_all) {
lite::arm::math::reduce_prod_all(input, output, x_dims.production());
} else {
CHECK_EQ(x_rank, 4U);
int n_in = x_dims[0];
int c_in = x_dims[1];
int h_in = x_dims[2];
int w_in = x_dims[3];
if (dim.size() == 1) {
switch (dim[0]) {
case 0:
lite::arm::math::reduce_prod_n(input, output, n_in, c_in, h_in, w_in);
break;
case 1:
lite::arm::math::reduce_prod_c(input, output, n_in, c_in, h_in, w_in);
break;
case 2:
lite::arm::math::reduce_prod_h(input, output, n_in, c_in, h_in, w_in);
break;
case 3:
lite::arm::math::reduce_prod_w(input, output, n_in, c_in, h_in, w_in);
break;
default:
LOG(FATAL) << "dim[0] should be less than 4.";
}
} else if (dim.size() == 2) {
if (dim[0] == 0 && dim[1] == 1) {
lite::arm::math::reduce_prod_nc(input, output, n_in, c_in, h_in, w_in);
} else if (dim[0] == 1 && dim[1] == 2) {
lite::arm::math::reduce_prod_ch(input, output, n_in, c_in, h_in, w_in);
} else if (dim[0] == 2 && dim[1] == 3) {
lite::arm::math::reduce_prod_hw(input, output, n_in, c_in, h_in, w_in);
} else {
LOG(FATAL)
<< "Only the dim pairs {0,1}, {1,2} and {2,3} are supported for now.";
}
} else {
LOG(FATAL) << "dim's size over than 2, which is not supported now!!";
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
using reduce_prob_arm_int32 =
paddle::lite::kernels::arm::ReduceProdCompute<int, PRECISION(kInt32)>;
using reduce_prob_arm_float =
paddle::lite::kernels::arm::ReduceProdCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(
reduce_prod, kARM, kInt32, kNCHW, reduce_prob_arm_int32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize();
REGISTER_LITE_KERNEL(
reduce_prod, kARM, kFloat, kNCHW, reduce_prob_arm_float, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
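A sketch of driving the float kernel directly, in the style of the kernel unit tests elsewhere in this commit; the shapes and values are illustrative assumptions:
paddle::lite::kernels::arm::ReduceProdCompute<float, PRECISION(kFloat)> kernel;
paddle::lite::operators::ReduceParam param;
paddle::lite::Tensor x, out;
x.Resize({1, 3, 2, 2});
out.Resize({1, 3, 1, 1});
auto* xd = x.mutable_data<float>();
for (int i = 0; i < x.numel(); ++i) xd[i] = static_cast<float>(i + 1);
param.x = &x;
param.output = &out;
param.dim = {2, 3};       // reduce over H and W -> reduce_prod_hw path
param.keep_dim = true;
param.reduce_all = false;
kernel.SetParam(param);
kernel.Run();             // out holds {1*2*3*4, 5*6*7*8, 9*10*11*12}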
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
template <typename T, PrecisionType Ptype>
class ReduceProdCompute : public KernelLite<TARGET(kARM), Ptype> {
public:
void Run() override;
virtual ~ReduceProdCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -37,5 +37,5 @@ void ShapeCompute::Run() {
REGISTER_LITE_KERNEL(
shape, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ShapeCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize();
......@@ -42,11 +42,10 @@ inline std::vector<int32_t> get_new_data_from_tensor(
return vec_new_data;
}
void SliceCompute::PrepareForRun() {}
void SliceCompute::Run() {
template <typename T, PrecisionType PType>
void SliceCompute<T, PType>::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::SliceParam>();
auto& param = this->template Param<operators::SliceParam>();
auto in = param.X;
auto in_dims = in->dims();
......@@ -156,8 +155,8 @@ void SliceCompute::Run() {
}
auto new_out_dims = out->dims();
const auto* x_data = in->data<int>();
auto* o_data = out->mutable_data<int>();
const auto* x_data = in->template data<T>();
auto* o_data = out->template mutable_data<T>();
lite::arm::math::slice(
x_data, in_dims.data(), axes, starts, ends, o_data, &ctx);
}
......@@ -167,8 +166,9 @@ void SliceCompute::Run() {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
slice, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SliceCompute, def)
using slice_float =
paddle::lite::kernels::arm::SliceCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(slice, kARM, kFloat, kNCHW, slice_float, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("StartsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("EndsTensor", {LiteType::GetTensorTy(TARGET(kARM))})
......
......@@ -22,12 +22,12 @@ namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class SliceCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
template <typename T, PrecisionType PType>
class SliceCompute : public KernelLite<TARGET(kARM), PType> {
public:
using param_t = operators::SliceParam;
void PrepareForRun() override;
void Run() override;
~SliceCompute() {}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/split_lod_tensor_compute.h"
#include <string>
#include <utility>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
struct CopyRange {
size_t begin;
size_t end;
};
void SplitLodTensorCompute::Run() {
auto &param = Param<operators::SplitLodTensorParam>();
const lite::Tensor *x = param.x;
const lite::Tensor *mask = param.mask;
lite::Tensor *out_true = param.out_true;
lite::Tensor *out_false = param.out_false;
int level = param.level;
auto &x_lod = x->lod();
auto &mask_dim = mask->dims();
auto *mask_data = mask->data<bool>();
std::vector<std::vector<CopyRange>> copy_ranges(2);
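// copy_ranges[0]: row ranges routed to OutFalse; copy_ranges[1]: routed to OutTrue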
// set out_true/out_false lod
for (size_t t = 0; t < 2; t++) {
LoD *lod = nullptr;
if (t == 0) {
lod = out_false->mutable_lod();
} else {
lod = out_true->mutable_lod();
}
lod->clear();
for (size_t i = 0; i < static_cast<size_t>(mask_dim[0]); i++) {
VLOG(4) << "mask: " << mask_data[i];
if (static_cast<size_t>(mask_data[i]) == t) {
size_t start_idx = i;
auto lod_and_offset = lite::arm::math::GetSubLoDAndAbsoluteOffset(
x_lod, start_idx, start_idx + 1, level);
auto &lod_length = lod_and_offset.first;
lite::arm::math::AppendLoD(lod, lod_length);
size_t start_offset = lod_and_offset.second.first;
size_t end_offset = lod_and_offset.second.second;
copy_ranges[t].emplace_back(CopyRange{start_offset, end_offset});
}
}
}
for (size_t t = 0; t < 2; ++t) {
Tensor *out;
if (t == 0) {
out = out_false;
} else {
out = out_true;
}
auto &ranges = copy_ranges[t];
size_t height = std::accumulate(
ranges.begin(), ranges.end(), 0UL, [](size_t a, const CopyRange &b) {
return a + b.end - b.begin;
});
auto x_dim = x->dims();
x_dim[0] = static_cast<int64_t>(height);
out->Resize(x_dim);
auto *x_data = x->data<float>();
auto *out_data = out->mutable_data<float>();
auto out_dim = out->dims();
size_t base_num = static_cast<size_t>(out->numel() / out_dim[0]);
size_t offset = 0;
for (auto &each_range : ranges) {
size_t len = each_range.end - each_range.begin;
if (len == 0) {
continue;
}
auto *x_from = x_data + base_num * each_range.begin;
auto *out_dest = out_data + base_num * offset;
size_t copy_num = base_num * len * sizeof(float);
memcpy(out_dest, x_from, copy_num);
offset += len;
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(split_lod_tensor,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::SplitLodTensorCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Mask", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.BindOutput("OutTrue", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("OutFalse", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/split_lod_tensor_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class SplitLodTensorCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::SplitLodTensorParam;
void Run() override;
virtual ~SplitLodTensorCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/split_lod_tensor_compute.h"
#include <gtest/gtest.h>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
TEST(split_lod_tensor_arm, retrive_op) {
auto kernel =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"split_lod_tensor");
ASSERT_FALSE(kernel.empty());
ASSERT_TRUE(kernel.front());
}
TEST(split_lod_tensor_arm, init) {
SplitLodTensorCompute cpt;
ASSERT_EQ(cpt.precision(), PRECISION(kFloat));
ASSERT_EQ(cpt.target(), TARGET(kARM));
}
TEST(split_lod_tensor_arm_0, compute) {
DeviceInfo::Init();
Tensor x;
Tensor mask;
Tensor out_true;
Tensor out_false;
int level = 0;
// set dims and lod
VLOG(5) << "set dims and lod";
x.Resize({5, 1});
LoD x_lod;
std::vector<uint64_t> x_lod0 = {0, 2, 3, 5};
x_lod.push_back(x_lod0);
x.set_lod(x_lod);
mask.Resize({3, 1});
out_true.Resize({5, 1});
out_false.Resize({5, 1});
// initialize data
VLOG(5) << "initialize data";
auto* x_data = x.mutable_data<float>();
for (size_t i = 0; i < x.numel(); i++) {
x_data[i] = static_cast<float>(i);
}
auto* mask_data = mask.mutable_data<bool>();
for (size_t i = 0; i < mask.numel(); i++) {
mask_data[i] = static_cast<bool>(i % 2);
}
// prepare kernel params and run to obtain output_data
VLOG(5) << "prepare kernel params";
SplitLodTensorCompute op;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
op.SetContext(std::move(ctx));
VLOG(5) << "run kernel";
operators::SplitLodTensorParam param;
param.x = &x;
param.mask = &mask;
param.out_true = &out_true;
param.out_false = &out_false;
param.level = level;
op.SetParam(param);
op.Launch();
VLOG(5) << "obtain results";
auto* out_true_data = out_true.data<float>();
std::vector<float> out_true_ref = {2};
for (int i = 0; i < out_true.numel(); i++) {
LOG(INFO) << out_true_data[i];
EXPECT_NEAR(out_true_data[i], out_true_ref[i], 1e-5);
}
auto* out_false_data = out_false.data<float>();
std::vector<float> out_false_ref = {0, 1, 3, 4};
for (int i = 0; i < out_false.numel(); i++) {
LOG(INFO) << out_false_data[i];
EXPECT_NEAR(out_false_data[i], out_false_ref[i], 1e-5);
}
}
TEST(split_lod_tensor_arm_1, compute) {
DeviceInfo::Init();
Tensor x;
Tensor mask;
Tensor out_true;
Tensor out_false;
int level = 0;
// set dims and lod
x.Resize({9, 3});
LoD x_lod;
std::vector<uint64_t> x_lod0 = {0, 2, 3, 5};
std::vector<uint64_t> x_lod1 = {0, 1, 3, 6, 8, 9};
x_lod.push_back(x_lod0);
x_lod.push_back(x_lod1);
x.set_lod(x_lod);
mask.Resize({3, 1});
out_true.Resize({9, 2});
out_false.Resize({9, 2});
// initialize data
auto* x_data = x.mutable_data<float>();
for (size_t i = 0; i < x.numel(); i++) {
x_data[i] = static_cast<float>(i);
}
auto* mask_data = mask.mutable_data<bool>();
for (size_t i = 0; i < mask.numel(); i++) {
mask_data[i] = static_cast<bool>(i % 2);
}
// prepare kernel params and run to obtain output_data
SplitLodTensorCompute op;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<ARMContext>();
op.SetContext(std::move(ctx));
operators::SplitLodTensorParam param;
param.x = &x;
param.mask = &mask;
param.out_true = &out_true;
param.out_false = &out_false;
param.level = level;
op.SetParam(param);
op.Launch();
auto* out_true_data = out_true.data<float>();
std::vector<float> out_true_ref = {9, 10, 11, 12, 13, 14, 15, 16, 17};
for (int i = 0; i < out_true.numel(); i++) {
EXPECT_NEAR(out_true_data[i], out_true_ref[i], 1e-5);
}
auto* out_false_data = out_false.data<float>();
std::vector<float> out_false_ref = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 19, 20, 21, 22, 23, 24, 25, 26};
for (int i = 0; i < out_false.numel(); i++) {
EXPECT_NEAR(out_false_data[i], out_false_ref[i], 1e-5);
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(split_lod_tensor, kARM, kFloat, kNCHW, def);
......@@ -83,6 +83,9 @@ add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS})
add_operator(split_lod_tensor_op_lite extra SRCS split_lod_tensor_op.cc DEPS ${op_DEPS})
add_operator(merge_lod_tensor_op_lite extra SRCS merge_lod_tensor_op.cc DEPS ${op_DEPS})
add_operator(reduce_prod_op_lite extra SRCS reduce_prod_op.cc DEPS ${op_DEPS})
add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
......@@ -94,6 +97,8 @@ add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_
add_operator(var_conv_2d_op_lite extra SRCS var_conv_2d_op.cc DEPS ${op_DEPS})
add_operator(attention_padding_mask_op_lite extra SRCS attention_padding_mask_op.cc DEPS ${op_DEPS})
add_operator(sequence_arithmetic_op_lite extra SRCS sequence_arithmetic_op.cc DEPS ${op_DEPS})
add_operator(conditional_block_op_lite extra SRCS conditional_block_op.cc DEPS ${op_DEPS})
add_operator(collect_fpn_proposals_op_lite extra SRCS collect_fpn_proposals_op.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/collect_fpn_proposals_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool CollectFpnProposalsOpLite::CheckShape() const {
CHECK_OR_FALSE(!param_.multi_level_rois.empty());
CHECK_OR_FALSE(!param_.multi_level_scores.empty());
CHECK_OR_FALSE(param_.fpn_rois);
for (auto item : param_.multi_level_rois) {
auto dims = item->dims();
CHECK_OR_FALSE(dims[1] == 4);
}
for (auto item : param_.multi_level_scores) {
auto dims = item->dims();
CHECK_OR_FALSE(dims[1] == 2);
}
for (int i = 0; i < param_.multi_level_rois.size(); i++) {
auto roi = param_.multi_level_rois[i];
auto roi_lod = roi->lod();
auto score = param_.multi_level_scores[i];
auto score_lod = score->lod();
CHECK_OR_FALSE(roi_lod == score_lod);
}
return true;
}
bool CollectFpnProposalsOpLite::InferShape() const {
param_.fpn_rois->Resize({param_.post_nms_topN, 4});
return true;
}
bool CollectFpnProposalsOpLite::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
auto rois_names = op_desc.Input("MultiLevelRois");
for (const auto& var_name : rois_names) {
param_.multi_level_rois.push_back(
scope->FindVar(var_name)->GetMutable<lite::Tensor>());
}
auto scores_names = op_desc.Input("MultiLevelScores");
for (const auto& var_name : scores_names) {
param_.multi_level_scores.push_back(
scope->FindVar(var_name)->GetMutable<lite::Tensor>());
}
auto fpn_rois = op_desc.Output("FpnRois").front();
param_.fpn_rois = scope->FindVar(fpn_rois)->GetMutable<lite::Tensor>();
param_.post_nms_topN = op_desc.GetAttr<int>("post_nms_topN");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(collect_fpn_proposals,
paddle::lite::operators::CollectFpnProposalsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class CollectFpnProposalsOpLite : public OpLite {
public:
CollectFpnProposalsOpLite() {}
explicit CollectFpnProposalsOpLite(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "collect_fpn_proposals"; }
private:
mutable CollectFpnProposalsParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/conditional_block_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ConditionalBlockOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.cond);
CHECK_OR_FALSE(param_.sub_block);
CHECK_OR_FALSE(param_.scope);
return true;
}
bool ConditionalBlockOpLite::InferShape() const { return true; }
bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto condition = op_desc.Input("Cond").front();
param_.cond = scope->FindVar(condition)->GetMutable<lite::Tensor>();
auto inputs = op_desc.Input("Input");
for (auto var : inputs) {
param_.x.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
auto outs = op_desc.Output("Out");
for (auto var : outs) {
param_.outs.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
param_.is_scalar_condition = op_desc.GetAttr<bool>("is_scalar_condition");
// sub_block_ is filled in via SetSubBlock(), which the program builder
// (lite/core/program.cc) calls before AttachImpl runs.
param_.sub_block = sub_block_;
param_.scope = scope;
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(conditional_block,
paddle::lite::operators::ConditionalBlockOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class ConditionalBlockOpLite : public OpLite {
public:
ConditionalBlockOpLite() {}
explicit ConditionalBlockOpLite(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "conditional_block"; }
void SetSubBlock(cpp::BlockDesc *desc) { sub_block_ = desc; }
private:
mutable ConditionalBlockParam param_;
cpp::BlockDesc *sub_block_{nullptr};
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/merge_lod_tensor_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
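// Mask must be an [N, 1] tensor; it selects, row by row, whether the output
// takes its data from InTrue or InFalse.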
bool MergeLodTensorOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.mask);
CHECK_OR_FALSE(param_.in_true);
CHECK_OR_FALSE(param_.in_false);
CHECK_OR_FALSE(param_.out);
const auto mask_dims = param_.mask->dims();
CHECK_OR_FALSE(mask_dims.size() == 2);
CHECK_OR_FALSE(mask_dims[1] == 1);
return true;
}
bool MergeLodTensorOpLite::InferShape() const {
auto dims = param_.in_true->dims();
param_.out->Resize(dims);
return true;
}
bool MergeLodTensorOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto x = op_desc.Input("X").front();
auto mask = op_desc.Input("Mask").front();
auto in_true = op_desc.Input("InTrue").front();
auto in_false = op_desc.Input("InFalse").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.mask = scope->FindVar(mask)->GetMutable<lite::Tensor>();
param_.in_true = scope->FindVar(in_true)->GetMutable<lite::Tensor>();
param_.in_false = scope->FindVar(in_false)->GetMutable<lite::Tensor>();
auto out = op_desc.Output("Out").front();
param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.level = op_desc.GetAttr<int>("level");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(merge_lod_tensor,
paddle::lite::operators::MergeLodTensorOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class MergeLodTensorOpLite : public OpLite {
public:
MergeLodTensorOpLite() {}
explicit MergeLodTensorOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "merge_lod_tensor"; }
private:
mutable MergeLodTensorParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -1056,6 +1056,39 @@ struct SearchGrnnParam {
lite::Tensor* layout_input{};
};
struct SplitLodTensorParam {
const lite::Tensor* x{};
const lite::Tensor* mask{};
lite::Tensor* out_true{};
lite::Tensor* out_false{};
int level{};
};
struct MergeLodTensorParam {
const lite::Tensor* x{};
const lite::Tensor* mask{};
const lite::Tensor* in_true{};
const lite::Tensor* in_false{};
lite::Tensor* out{};
int level{};
};
struct ConditionalBlockParam {
const lite::Tensor* cond{};
std::vector<lite::Tensor*> x{};
std::vector<lite::Tensor*> outs{};
cpp::BlockDesc* sub_block{};
Scope* scope{};
bool is_scalar_condition{};
};
struct CollectFpnProposalsParam {
std::vector<lite::Tensor*> multi_level_rois{};
std::vector<lite::Tensor*> multi_level_scores{};
lite::Tensor* fpn_rois{};
int post_nms_topN{};
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/reduce_prod_op.h"
#include <algorithm>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ReduceProdOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
return true;
}
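// Output shape follows the usual reduce semantics: reduced axes become 1
// when keep_dim is set and are dropped otherwise; reduce_all (or an empty
// dim list) collapses the input to a single value.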
bool ReduceProdOpLite::InferShape() const {
auto x = param_.x;
auto out = param_.output;
std::vector<int> dim = param_.dim;
bool reduce_all = param_.reduce_all;
bool keep_dim = param_.keep_dim;
auto x_dims = x->dims();
auto x_rank = x_dims.size();
CHECK_OR_FALSE(x_rank <= 6U);
for (size_t i = 0; i < dim.size(); i++) {
if (dim[i] < 0) {
dim[i] = x_rank + dim[i];
}
CHECK_OR_FALSE(static_cast<size_t>(dim[i]) < x_rank);
}
std::sort(dim.begin(), dim.end());
if (reduce_all || dim.size() == 0) {
if (keep_dim) {
out->Resize(std::vector<int64_t>(x_rank, 1));
} else {
out->Resize({1});
}
} else {
auto dims_vector = x_dims.Vectorize();
if (keep_dim) {
for (size_t i = 0; i < dim.size(); ++i) {
dims_vector[dim[i]] = 1;
}
} else {
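// Mark the reduced axes with kDelFlag, then erase them from the shape.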
const int kDelFlag = -2;
for (size_t i = 0; i < dim.size(); ++i) {
dims_vector[dim[i]] = kDelFlag;
}
dims_vector.erase(
std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
if (!keep_dim && dims_vector.size() == 0) {
dims_vector.push_back(1);
}
out->Resize(dims_vector);
if (dim.size() > 0 && dim[0] != 0) {
out->set_lod(x->lod());
}
}
return true;
}
bool ReduceProdOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto x = op_desc.Input("X").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
auto output = op_desc.Output("Out").front();
param_.output = scope->FindVar(output)->GetMutable<lite::Tensor>();
param_.dim = op_desc.GetAttr<std::vector<int>>("dim");
param_.keep_dim = op_desc.GetAttr<bool>("keep_dim");
param_.reduce_all = op_desc.GetAttr<bool>("reduce_all");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(reduce_prod, paddle::lite::operators::ReduceProdOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace operators {
class ReduceProdOpLite : public OpLite {
public:
ReduceProdOpLite() {}
explicit ReduceProdOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "reduce_prod"; }
private:
mutable ReduceParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -83,7 +83,6 @@ bool SliceOp::InferShape() const {
if (axes[0] != 0) {
param_.Out->set_lod(param_.X->lod());
}
LOG(INFO) << "infer shape done";
return true;
}
......@@ -162,7 +161,6 @@ bool SliceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
CHECK_EQ(ends_size, param_.axes.size())
<< "The size of ends must be equal to the size of axes.";
}
LOG(INFO) << "attach impl done";
return true;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/split_lod_tensor_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
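// Mask must be an [N, 1] tensor; rows with a true mask value are routed to
// OutTrue and the remaining rows to OutFalse.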
bool SplitLodTensorOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.mask);
CHECK_OR_FALSE(param_.out_true);
CHECK_OR_FALSE(param_.out_false);
const auto mask_dims = param_.mask->dims();
CHECK_OR_FALSE(mask_dims.size() == 2);
CHECK_OR_FALSE(mask_dims[1] == 1);
return true;
}
bool SplitLodTensorOpLite::InferShape() const {
auto x_dims = param_.x->dims();
param_.out_true->Resize(x_dims);
param_.out_false->Resize(x_dims);
return true;
}
bool SplitLodTensorOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto x = op_desc.Input("X").front();
auto mask = op_desc.Input("Mask").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.mask = scope->FindVar(mask)->GetMutable<lite::Tensor>();
auto out_true = op_desc.Output("OutTrue").front();
auto out_false = op_desc.Output("OutFalse").front();
param_.out_true = scope->FindVar(out_true)->GetMutable<lite::Tensor>();
param_.out_false = scope->FindVar(out_false)->GetMutable<lite::Tensor>();
param_.level = op_desc.GetAttr<int>("level");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(split_lod_tensor,
paddle::lite::operators::SplitLodTensorOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SplitLodTensorOpLite : public OpLite {
public:
SplitLodTensorOpLite() {}
explicit SplitLodTensorOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "split_lod_tensor"; }
private:
mutable SplitLodTensorParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -33,6 +33,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
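// Naive reference implementations: reduce an NCHW input along the named
// axis by multiplying its elements.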
void reduce_prod_n(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int hw_size = height_in * width_in;
int chw_size = channel_in * hw_size;
int data_index, src_index;
for (int c = 0; c < channel_in; ++c) {
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
data_index = c * hw_size + h * width_in + w;
dst[data_index] = 1.0;
for (int n = 0; n < num_in; ++n) {
src_index = n * chw_size + data_index;
dst[data_index] *= src[src_index];
}
}
}
}
}
void reduce_prod_c(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int hw_size = height_in * width_in;
int chw_size = hw_size * channel_in;
int data_index, src_index0, src_index;
for (int n = 0; n < num_in; ++n) {
for (int h = 0; h < height_in; ++h) {
for (int w = 0; w < width_in; ++w) {
data_index = n * hw_size + h * width_in + w;
src_index0 = n * chw_size + h * width_in + w;
dst[data_index] = 1.0;
for (int c = 0; c < channel_in; ++c) {
src_index = src_index0 + c * hw_size;
dst[data_index] *= src[src_index];
}
}
}
}
}
void reduce_prod_h(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int cw_size = channel_in * width_in;
int chw_size = cw_size * height_in;
int hw_size = height_in * width_in;
int data_index, src_index, src_index0;
for (int n = 0; n < num_in; ++n) {
for (int c = 0; c < channel_in; ++c) {
for (int w = 0; w < width_in; ++w) {
data_index = n * cw_size + c * width_in + w;
src_index0 = n * chw_size + c * hw_size + w;
dst[data_index] = 1.0;
for (int h = 0; h < height_in; ++h) {
src_index = src_index0 + h * width_in;
dst[data_index] *= src[src_index];
}
}
}
}
}
void reduce_prod_w(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
int ch_size = channel_in * height_in;
int hw_size = height_in * width_in;
int chw_size = ch_size * width_in;
int data_index = 0;
int src_index0 = 0;
int src_index = 0;
for (int n = 0; n < num_in; ++n) {
for (int c = 0; c < channel_in; ++c) {
for (int h = 0; h < height_in; ++h) {
data_index = n * ch_size + c * height_in + h;
src_index0 = n * chw_size + c * hw_size + h * width_in;
dst[data_index] = 1.0;
for (int w = 0; w < width_in; ++w) {
src_index = src_index0 + w;
dst[data_index] *= src[src_index];
}
}
}
}
}
void reduce_prod_all(const float* src, float* dst, int64_t total_num) {
dst[0] = 1.0;
for (int64_t n = 0; n < total_num; ++n) {
dst[0] *= src[n];
}
}
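// Two-axis reductions are composed from the single-axis helpers through a
// temporary tensor.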
void reduce_prod_nc(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce n first.
DDimLite ddimA({1, channel_in, height_in, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
float* tmp_out = tensor_tmp.mutable_data<float>();
reduce_prod_n(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_prod_c(tmp_out, dst, 1, channel_in, height_in, width_in);
}
void reduce_prod_ch(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce c first
DDimLite ddimA({num_in, 1, height_in, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
float* tmp_out = tensor_tmp.mutable_data<float>();
reduce_prod_c(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_prod_h(tmp_out, dst, num_in, 1, height_in, width_in);
}
void reduce_prod_hw(const float* src,
float* dst,
int num_in,
int channel_in,
int height_in,
int width_in) {
// reduce h first
DDimLite ddimA({num_in, channel_in, 1, width_in});
lite::Tensor tensor_tmp;
tensor_tmp.Resize(ddimA);
float* tmp_out = tensor_tmp.mutable_data<float>();
reduce_prod_h(src, tmp_out, num_in, channel_in, height_in, width_in);
reduce_prod_w(tmp_out, dst, num_in, channel_in, 1, width_in);
}
class ReduceProdComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "x";
std::string output_ = "out";
std::vector<int> dim_{};
bool keep_dim_{};
DDim x_dims_{};
bool reduce_all_{};
public:
ReduceProdComputeTester(const Place& place,
const std::string& alias,
std::vector<int> dim,
bool keep_dim,
DDim x_dims,
bool reduce_all)
: TestCase(place, alias),
dim_(dim),
keep_dim_(keep_dim),
x_dims_(x_dims),
reduce_all_(reduce_all) {}
void RunBaseline(Scope* scope) override {
auto* x = scope->FindMutableTensor(input_);
auto* x_data = x->data<float>();
auto x_rank = x_dims_.size();
auto* out = scope->NewTensor(output_);
if (!dim_.empty()) {
for (size_t i = 0; i < dim_.size(); i++) {
if (dim_[i] < 0) {
dim_[i] += x_rank;
}
}
}
std::sort(dim_.begin(), dim_.end());
if (reduce_all_ || dim_.size() == 0) {
if (keep_dim_) {
out->Resize(std::vector<int64_t>(x_rank, 1));
} else {
out->Resize({1});
}
} else {
std::vector<int64_t> out_dims;
for (size_t i = 0; i < x_dims_.size(); i++) {
out_dims.push_back(x_dims_[i]);
}
if (keep_dim_) {
for (size_t i = 0; i < dim_.size(); ++i) {
out_dims[dim_[i]] = 1L;
}
} else {
int64_t kDelFlag = -2;
for (size_t i = 0; i < dim_.size(); ++i) {
out_dims[dim_[i]] = kDelFlag;
}
out_dims.erase(std::remove(out_dims.begin(), out_dims.end(), kDelFlag),
out_dims.end());
}
if (!keep_dim_ && out_dims.empty()) {
out_dims.push_back(1);
}
out->Resize(out_dims);
}
auto* out_data = out->mutable_data<float>();
if (reduce_all_ || dim_.empty()) {
reduce_prod_all(x_data, out_data, x_dims_.production());
} else {
CHECK_EQ(x_rank, 4U);
int in_n = x_dims_[0];
int in_c = x_dims_[1];
int in_h = x_dims_[2];
int in_w = x_dims_[3];
if (dim_.size() == 1) {
switch (dim_[0]) {
case 0:
reduce_prod_n(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 1:
reduce_prod_c(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 2:
reduce_prod_h(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 3:
reduce_prod_w(x_data, out_data, in_n, in_c, in_h, in_w);
break;
default:
LOG(FATAL) << "error!!!";
}
} else if (dim_.size() == 2) {
if (dim_[0] == 0 && dim_[1] == 1) {
reduce_prod_nc(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 1 && dim_[1] == 2) {
reduce_prod_ch(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 2 && dim_[1] == 3) {
reduce_prod_hw(x_data, out_data, in_n, in_c, in_h, in_w);
} else {
LOG(FATAL) << "invalid dims_!!";
}
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("reduce_prod");
op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("dim", dim_);
op_desc->SetAttr("keep_dim", keep_dim_);
op_desc->SetAttr("reduce_all", reduce_all_);
}
void PrepareData() override {
std::vector<float> data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) {
data[i] = (i + 1) * 1.0;
}
SetCommonTensor(input_, x_dims_, data.data());
}
};
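// Sweeps single- and double-axis reductions over small NCHW shapes, then
// checks one reduce_all case on a 2-D input.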
void test_reduce_prod(Place place) {
std::vector<std::vector<int>> reduce_dim{
{0}, {1}, {2}, {3}, {0, 1}, {1, 2}, {2, 3}, {-2, -1}};
for (auto n : {1, 3}) {
for (auto c : {1, 2}) {
for (auto h : {1, 3}) {
for (auto w : {1, 3}) {
for (bool keep_dim : {false, true}) {
for (auto dim : reduce_dim) {
auto x_dims = DDim(std::vector<int64_t>({n, c, h, w}));
std::unique_ptr<arena::TestCase> tester(
new ReduceProdComputeTester(
place, "def", dim, keep_dim, x_dims, false));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
}
}
std::vector<int> dim = {0};
bool keep_dim = false;
bool reduce_all = true;
auto x_dims = DDim({2, 2});
std::unique_ptr<arena::TestCase> tester(new ReduceProdComputeTester(
place, "def", dim, keep_dim, x_dims, reduce_all));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
TEST(ReduceProd, precision) {
// #ifdef LITE_WITH_X86
// Place place(TARGET(kX86));
// #endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_reduce_prod(place);
#endif
}
} // namespace lite
} // namespace paddle