未验证 提交 d18c8a40 编写于 作者: Z zhangshijin 提交者: GitHub

Merge pull request #36 from Cambricon/yolo3-op_first-conv_3dim-iput-mlu

add 1)Yolo3 op 2)first conv 2)3dim iput mlu
......@@ -38,6 +38,14 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
Env<TARGET(kMLU)>::Init();
mlu_core_version_ = config.mlu_core_version();
mlu_core_number_ = config.mlu_core_number();
use_first_conv_ = config.use_first_conv();
mean_vec_ = config.mean();
std_vec_ = config.std();
lite::DeviceInfo::Global().SetMLURunMode(mlu_core_version_,
mlu_core_number_,
use_first_conv_,
mean_vec_,
std_vec_);
#endif // LITE_WITH_MLU
auto places = config.valid_places();
std::vector<std::string> passes{};
......@@ -87,9 +95,6 @@ std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
void CxxPaddleApiImpl::Run() {
#ifdef LITE_WITH_ARM
lite::DeviceInfo::Global().SetRunMode(mode_, threads_);
#endif
#ifdef LITE_WITH_MLU
lite::DeviceInfo::Global().SetMLURunMode(mlu_core_version_, mlu_core_number_);
#endif
raw_predictor_.Run();
}
......
......@@ -108,6 +108,9 @@ class LITE_API PaddlePredictor {
lite_api::PowerMode mode_{lite_api::LITE_POWER_NO_BIND};
lite_api::MLUCoreVersion mlu_core_version_{lite_api::MLU_270};
int mlu_core_number_{1};
bool use_first_conv_{false};
std::vector<float> mean_vec_;
std::vector<float> std_vec_;
};
/// Base class for all the configs.
......
......@@ -69,6 +69,9 @@ thread_local int64_t DeviceInfo::count_ = 0;
#ifdef LITE_WITH_MLU
thread_local cnmlCoreVersion_t DeviceInfo::mlu_core_version_{CNML_MLU270};
thread_local int DeviceInfo::mlu_core_number_{1};
thread_local bool DeviceInfo::use_first_conv_{false};
thread_local std::vector<float> DeviceInfo::mean_vec_;
thread_local std::vector<float> DeviceInfo::std_vec_;
#endif
#ifdef TARGET_IOS
......@@ -1087,7 +1090,10 @@ int DeviceInfo::Setup() {
#ifdef LITE_WITH_MLU
void DeviceInfo::SetMLURunMode(lite_api::MLUCoreVersion core_version,
int core_number) {
int core_number,
bool use_first_conv,
const std::vector<float>& mean_vec,
const std::vector<float>& std_vec) {
switch (core_version) {
case (lite_api::MLUCoreVersion::MLU_220):
mlu_core_version_ = CNML_MLU220;
......@@ -1100,11 +1106,21 @@ void DeviceInfo::SetMLURunMode(lite_api::MLUCoreVersion core_version,
break;
}
mlu_core_number_ = core_number;
use_first_conv_ = use_first_conv;
mean_vec_ = mean_vec;
std_vec_ = std_vec;
}
cnmlCoreVersion_t DeviceInfo::MLUCoreVersion() { return mlu_core_version_; }
int DeviceInfo::MLUCoreNumber() { return mlu_core_number_; }
bool DeviceInfo::UseFirstConv() { return use_first_conv_; }
const std::vector<float>& DeviceInfo::MeanVec() const { return mean_vec_; }
const std::vector<float>& DeviceInfo::StdVec() const { return std_vec_; }
#endif // LITE_WITH_MLU
void DeviceInfo::SetRunMode(lite_api::PowerMode mode, int thread_num) {
......
......@@ -56,9 +56,16 @@ class DeviceInfo {
void SetRunMode(lite_api::PowerMode mode, int thread_num);
#ifdef LITE_WITH_MLU
void SetMLURunMode(lite_api::MLUCoreVersion core_version, int core_number);
void SetMLURunMode(lite_api::MLUCoreVersion core_version,
int core_number,
bool use_first_conv,
const std::vector<float>& mean_vec,
const std::vector<float>& std_vec);
cnmlCoreVersion_t MLUCoreVersion();
int MLUCoreNumber();
bool UseFirstConv();
const std::vector<float>& MeanVec() const;
const std::vector<float>& StdVec() const;
#endif
void SetCache(int l1size, int l2size, int l3size);
void SetArch(ARMArch arch) { arch_ = arch; }
......@@ -114,6 +121,9 @@ class DeviceInfo {
#ifdef LITE_WITH_MLU
static thread_local cnmlCoreVersion_t mlu_core_version_;
static thread_local int mlu_core_number_;
static thread_local bool use_first_conv_;
static thread_local std::vector<float> mean_vec_;
static thread_local std::vector<float> std_vec_;
#endif
void SetDotInfo(int argc, ...);
......
......@@ -15,7 +15,6 @@
#include "lite/core/mir/mlu_postprocess_pass.h"
#include <list>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
......@@ -50,10 +49,9 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
op_desc.SetAttr<int>("out_dtype", 4); // FP16
op_desc.SetInput("X", {cur_node->AsArg().name});
op_desc.SetOutput("Out", {cast_arg_name});
} else if (op_type == "transpose") {
} else if (op_type == "layout") {
// NCHW -> NHWC
op_desc.SetAttr<std::vector<int>>("axis", {0, 2, 3, 1});
op_desc.SetInput("X", {cur_node->AsArg().name});
op_desc.SetInput("Input", {cur_node->AsArg().name});
op_desc.SetOutput("Out", {cast_arg_name});
} else if (op_type == "io_copy") {
op_desc.SetInput("Input", {cur_node->AsArg().name});
......@@ -72,8 +70,13 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
if (PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type)) {
is_found = true;
}
} else if (op_type == "transpose") {
is_found = true;
} else if (op_type == "layout") {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (DataLayoutCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
DataLayoutCompatible(*out_arg_ty, *cast_type)) {
is_found = true;
}
} else if (op_type == "io_copy") {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
......@@ -89,8 +92,13 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
// we pick the kernel
cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
auto& stmt = cast_inst->AsStmt();
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
if (op_type == "layout") {
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(TARGET(kX86)));
} else {
stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
stmt.picked_kernel().target()));
}
break;
}
}
......@@ -127,10 +135,9 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
op_desc.SetAttr<int>("out_dtype", 5); // FP16
op_desc.SetInput("X", {cast_arg_name});
op_desc.SetOutput("Out", {cur_node->AsArg().name});
} else if (op_type == "transpose") {
} else if (op_type == "layout") {
// NHWC -> NCHW
op_desc.SetAttr<std::vector<int>>("axis", {0, 3, 1, 2});
op_desc.SetInput("X", {cast_arg_name});
op_desc.SetInput("Input", {cast_arg_name});
op_desc.SetOutput("Out", {cur_node->AsArg().name});
} else if (op_type == "io_copy") {
op_desc.SetInput("Input", {cast_arg_name});
......@@ -151,8 +158,13 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
if (PrecisionCompatibleTo(*in_arg_ty, *cast_type)) {
is_found = true;
}
} else if (op_type == "transpose") {
is_found = true;
} else if (op_type == "layout") {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (DataLayoutCompatible(*in_arg_ty, *cast_type) &&
DataLayoutCompatible(*out_arg_ty, *cur_node->AsArg().type)) {
is_found = true;
}
} else if (op_type == "io_copy") {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
......@@ -168,8 +180,13 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
// we pick the kernel
cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
auto& stmt = cast_inst->AsStmt();
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
if (op_type == "layout") {
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(TARGET(kX86)));
} else {
stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
stmt.picked_kernel().target()));
}
break;
}
}
......@@ -193,12 +210,16 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
auto* cur_node = head_node;
const auto name_prefix =
head_node->AsArg().name + string_format("_%p", inst_node) + "/trans_";
bool is_first_conv_head =
std::find(first_conv_nodes_.begin(),
first_conv_nodes_.end(),
head_node->AsArg().name) != first_conv_nodes_.end();
// layout cast node
if (head_type->layout() != inst_type->layout()) {
cur_node = InsertCastBefore(
"transpose",
name_prefix + "transpose",
"layout",
name_prefix + "layout",
graph,
cur_node,
inst_node,
......@@ -207,7 +228,7 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
}
// precision cast node
if (head_type->precision() != inst_type->precision()) {
if (head_type->precision() != inst_type->precision() && !is_first_conv_head) {
cur_node = InsertCastBefore(
"cast",
name_prefix + "cast",
......@@ -346,8 +367,8 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
// layout cast node
if (tail_type->layout() != inst_type->layout()) {
cur_node = InsertCastAfter(
"transpose",
name_prefix + "transpose",
"layout",
name_prefix + "layout",
graph,
cur_node,
inst_node,
......@@ -415,6 +436,49 @@ void MLUPostprocessPass::RecreateOp(Node* inst_node, SSAGraph* graph) {
}
}
bool MLUPostprocessPass::IsFirstConvInSubgraph(Node* arg_node, Node* inst) {
auto* block_desc =
static_cast<operators::SubgraphOp*>(inst->AsStmt().op().get())
->GetSubBlock();
for (int op_idx = 0; op_idx < block_desc->OpsSize(); op_idx++) {
auto op_desc = block_desc->GetOp<cpp::OpDesc>(op_idx);
CHECK(op_desc);
if (op_desc->Type() == "conv2d") {
for (auto& names : op_desc->inputs()) {
if (std::find(names.second.begin(),
names.second.end(),
arg_node->AsArg().name) != names.second.end()) {
return true;
}
}
}
}
return false;
}
bool MLUPostprocessPass::IsFirstConvNode(Node* arg_node) {
CHECK(arg_node->IsArg());
for (auto& inst : arg_node->outlinks) {
if (inst->AsStmt().op_type() == "subgraph") {
return IsFirstConvInSubgraph(arg_node, inst);
}
}
return false;
}
void MLUPostprocessPass::GatherFirstConvNodes(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
if (!node.IsStmt()) continue;
if (node.AsStmt().op_type() == "feed") {
for (auto& out : node.outlinks) {
if (IsFirstConvNode(out)) {
first_conv_nodes_.insert(out->AsArg().name);
}
}
}
}
}
void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
if (!node.IsStmt()) continue;
......@@ -469,6 +533,10 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Thus here we change these args' layout to NHWC
ModifyLayout(graph.get());
if (lite::DeviceInfo::Global().UseFirstConv()) {
GatherFirstConvNodes(graph.get());
}
// insert io_copy, layout and precision cast of subgraph's inputs and outputs
for (auto& node : graph->mutable_nodes()) {
if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
......
......@@ -15,6 +15,7 @@
#pragma once
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "lite/core/mir/pass.h"
......@@ -107,6 +108,15 @@ class MLUPostprocessPass : public ProgramPass {
const Type* cast_type);
void RecreateOp(Node* inst_node, SSAGraph* graph);
void GatherFirstConvNodes(SSAGraph* graph);
bool IsFirstConvNode(Node* arg_node);
bool IsFirstConvInSubgraph(Node* arg_node, Node* inst);
private:
std::set<std::string> first_conv_nodes_;
};
} // namespace mir
......
......@@ -53,7 +53,7 @@ class SubgraphCastDisplayPass : public DebugPass {
for (auto p_in_stmt_node : p_in_arg_node->inlinks) {
CHECK(p_in_stmt_node->IsStmt());
std::string stmt_op_type = p_in_stmt_node->AsStmt().op_type();
if (stmt_op_type == "cast" || stmt_op_type == "transpose" ||
if (stmt_op_type == "cast" || stmt_op_type == "layout" ||
stmt_op_type == "io_copy") {
display_debug_info(*p_in_stmt_node, stmt_op_type, true, false);
} else {
......@@ -76,7 +76,7 @@ class SubgraphCastDisplayPass : public DebugPass {
for (auto p_out_stmt_node : p_out_arg_node->outlinks) {
CHECK(p_out_stmt_node->IsStmt());
std::string stmt_op_type = p_out_stmt_node->AsStmt().op_type();
if (stmt_op_type == "cast" || stmt_op_type == "transpose" ||
if (stmt_op_type == "cast" || stmt_op_type == "layout" ||
stmt_op_type == "io_copy") {
display_debug_info(*p_out_stmt_node, stmt_op_type, false, true);
} else {
......
......@@ -116,12 +116,12 @@ class Optimizer {
"argument_type_display_pass",
"mlu_subgraph_pass",
"mlu_postprocess_pass",
// subgraph_cast_display_pass
"runtime_context_assign_pass",
"argument_type_display_pass",
"mlu_postprocess_pass",
"memory_optimize_pass"}};
if (passes.size() == 1) {
......
......@@ -6,3 +6,4 @@ add_subdirectory(bridges)
add_kernel(subgraph_compute_mlu MLU basic SRCS subgraph_compute.cc DEPS ${lite_kernel_deps} ${mlu_subgraph_bridges})
add_kernel(io_copy_compute_mlu MLU basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps} ${math_mlu})
add_kernel(calib_compute_mlu MLU basic SRCS calib_compute.cc DEPS ${lite_kernel_deps} ${math_mlu})
add_kernel(layout_compute_mlu MLU basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} ${math_mlu})
......@@ -15,6 +15,10 @@ lite_cc_library(subgraph_bridge_elementwise_ops_mlu SRCS elementwise_ops.cc DEPS
lite_cc_library(subgraph_bridge_pool_op_mlu SRCS pool_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_softmax_op_mlu SRCS softmax_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_fc_op_mlu SRCS fc_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_scale_op_mlu SRCS scale_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_interp_op_mlu SRCS interpolate_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_concat_op_mlu SRCS concat_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_transpose_op_mlu SRCS transpose_op.cc DEPS ${subgraph_bridge_deps_mlu})
set(mlu_subgraph_bridges
subgraph_bridge_registry
subgraph_bridge_utility_mlu
......@@ -25,7 +29,11 @@ set(mlu_subgraph_bridges
subgraph_bridge_pool_op_mlu
subgraph_bridge_softmax_op_mlu
subgraph_bridge_fc_op_mlu
subgraph_bridge_transpose_op_mlu
subgraph_bridge_batch_norm_op_mlu
subgraph_bridge_scale_op_mlu
subgraph_bridge_interp_op_mlu
subgraph_bridge_concat_op_mlu
CACHE INTERNAL "mlu_subgraph_bridges")
lite_cc_library(subgraph_test_helper_mlu SRCS test_helper.cc DEPS ${mlu_subgraph_bridges})
......@@ -36,5 +44,8 @@ lite_cc_test(test_elementwise_converter_mlu SRCS elementwise_ops_test.cc DEPS sc
lite_cc_test(test_pool_converter_mlu SRCS pool_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_softmax_converter_mlu SRCS softmax_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_fc_converter_mlu SRCS fc_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_scale_converter_mlu SRCS scale_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_interp_converter_mlu SRCS interpolate_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_concat_converter_mlu SRCS concat_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_transpose_converter_mlu SRCS transpose_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
message(STATUS "+++++ mlu_subgraph_bridges: ${mlu_subgraph_bridges}")
......@@ -31,20 +31,34 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Create act node and set params from op
auto fp_type = graph->FPType();
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, fp_type);
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
cnmlActiveFunction_t act_type = OpTypeToCNMLActType(op_type);
cnmlBaseOp_t activation_op;
CNML_CALL(cnmlCreateActiveOp(&activation_op,
act_type,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor()));
if (op_type == "leaky_relu") {
auto alpha = op_info->GetAttr<float>("alpha");
std::vector<int64_t> shape = {1, 1, 1, 1};
std::string alpha_var_name = string_format("leaky_relu_alpha_%p", op);
auto alpha_tensor =
graph->AddNode(alpha_var_name, shape, CNML_CONST, CNML_NHWC, fp_type);
graph->BindConstRawData(alpha_var_name, &alpha, 1, true);
CNML_CALL(cnmlCreatePreluOp(&activation_op,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor(),
alpha_tensor->mlu_tensor()));
} else {
cnmlActiveFunction_t act_type = OpTypeToCNMLActType(op_type);
CNML_CALL(cnmlCreateActiveOp(&activation_op,
act_type,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor()));
}
graph->FuseOp(activation_op);
return SUCCESS;
}
......@@ -59,3 +73,6 @@ REGISTER_SUBGRAPH_BRIDGE(sigmoid,
paddle::lite::subgraph::mlu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(relu, kMLU, paddle::lite::subgraph::mlu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(tanh, kMLU, paddle::lite::subgraph::mlu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(leaky_relu,
kMLU,
paddle::lite::subgraph::mlu::ActConverter);
......@@ -134,7 +134,7 @@ void test_act(std::vector<int64_t> x_shape, std::string op_type) {
TEST(MLUBridges, activation) {
std::vector<std::vector<int64_t>> shapes{{1}, {2, 3}, {1, 2, 3, 4}};
std::vector<std::string> types{"sigmoid", "relu", "tanh"};
std::vector<std::string> types{"sigmoid", "relu", "tanh", "leaky_relu"};
for (auto x_shape : shapes) {
for (auto op_type : types) {
test_act(x_shape, op_type);
......@@ -150,3 +150,4 @@ TEST(MLUBridges, activation) {
USE_SUBGRAPH_BRIDGE(sigmoid, kMLU)
USE_SUBGRAPH_BRIDGE(relu, kMLU)
USE_SUBGRAPH_BRIDGE(tanh, kMLU)
USE_SUBGRAPH_BRIDGE(leaky_relu, kMLU)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X");
auto out_var_name = op_info->Output("Out").front();
auto param_axis = op_info->GetAttr<int>("axis");
// auto x = scope->FindVar(x_var_name[0])->GetMutable<Tensor>();
auto input_num = x_var_name.size();
std::vector<cnmlTensor_t> input_tensor;
std::vector<std::vector<int64_t>> input_dims;
for (auto x_name : x_var_name) {
CHECK(graph->HasNode(x_name));
input_tensor.push_back(graph->GetNode(x_name)->mlu_tensor());
auto x = scope->FindVar(x_name)->GetMutable<Tensor>();
input_dims.push_back(x->dims().Vectorize());
}
auto dims = input_dims[0].size();
int axis = (param_axis < 0) ? (param_axis + dims) : param_axis;
int nhwc_axis = -1;
if (dims == 4) {
int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2};
nhwc_axis = nchw_to_nhwc_axis_map[axis];
} else if (dims == 3) {
int nchw_to_nhwc_axis_map[3] = {0, 2, 1};
nhwc_axis = nchw_to_nhwc_axis_map[axis];
} else {
CHECK(0) << "Unsupport dims in mlu concat";
}
std::vector<int64_t> output_dims;
output_dims.assign(dims, 0);
/* std::cout << string_format("concat axis: %d(NCHW), %d(NHWC)", axis,
* nhwc_axis) << std::endl; */
for (int i = 0; i < output_dims.size(); ++i) {
if (i == nhwc_axis) {
for (auto& dim : input_dims) output_dims[i] += dim[i];
} else {
output_dims[i] = input_dims[0][i];
}
}
/* std::cout << string_format("concat output dim: %ld, %ld, %ld, %ld") <<
* std::endl; */
auto* output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
output->Resize(output_dims);
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
cnmlBaseOp_t concat_op;
cnmlTensor_t outputs[1];
outputs[0] = output_tensor->mlu_tensor();
CNML_CALL(cnmlCreateNdConcatOp(
&concat_op, nhwc_axis, input_tensor.data(), input_num, outputs, 1));
graph->FuseOp(concat_op);
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(concat,
kMLU,
paddle::lite::subgraph::mlu::ConcatConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/concat_op.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
void concat_ref(const std::shared_ptr<operators::ConcatOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = op_info->Input("X");
std::vector<lite::Tensor*> inputs;
for (auto var : x) {
inputs.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
int axis = op_info->GetAttr<int>("axis");
std::vector<lite::Tensor*> inputs_concat(inputs.size());
for (int j = 0; j < inputs.size(); ++j) {
inputs_concat[j] = inputs[j];
}
size_t num = inputs.size();
int rows = 1;
auto dim_0 = inputs[0]->dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int64_t> inputs_cols(inputs.size());
for (int i = 0; i < num; ++i) {
int t_cols = inputs[i]->numel() / rows;
out_cols += t_cols;
inputs_cols[i] = t_cols;
}
for (int k = 0; k < out_rows; ++k) {
float* dst_ptr = out->mutable_data<float>() + k * out_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = inputs_cols[j];
const float* src_prt = inputs[j]->data<float>() + k * col_len;
std::memcpy(dst_ptr + col_idx, src_prt, sizeof(float) * col_len);
col_idx += col_len;
}
}
}
void test_concat(std::vector<std::vector<int64_t>> input, int axis) {
std::string x_var_name = "x";
std::string y_var_name = "y";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
// prepare input&output variables
Scope scope;
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* y = scope.Var(y_var_name)->GetMutable<Tensor>();
x->Resize(DDim(input[0]));
y->Resize(DDim(input[1]));
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
CHECK_EQ(out->dims(), out_ref->dims());
// initialize input&output data
FillTensor<float>(x);
FillTensor<float>(y);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("concat");
opdesc.SetInput("X", {x_var_name, y_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", axis);
auto op = CreateOp<operators::ConcatOpLite>(opdesc, &scope);
concat_ref(op);
out_ref->CopyDataFrom(*out);
Tensor input_x, input_y;
input_x.Resize(DDim(input[0]));
input_y.Resize(DDim(input[1]));
transpose(x->mutable_data<float>(),
input_x.mutable_data<float>(),
{static_cast<int>(input[0][0]),
static_cast<int>(input[0][1]),
static_cast<int>(input[0][2]),
static_cast<int>(input[0][3])},
{0, 2, 3, 1});
transpose(y->mutable_data<float>(),
input_y.mutable_data<float>(),
{static_cast<int>(input[1][0]),
static_cast<int>(input[1][1]),
static_cast<int>(input[1][2]),
static_cast<int>(input[1][3])},
{0, 2, 3, 1});
auto os = out->dims();
out->Resize({static_cast<int>(os[0]),
static_cast<int>(os[2]),
static_cast<int>(os[3]),
static_cast<int>(os[1])});
x->CopyDataFrom(input_x);
y->CopyDataFrom(input_y);
x->Resize({static_cast<int>(input[0][0]),
static_cast<int>(input[0][2]),
static_cast<int>(input[0][3]),
static_cast<int>(input[0][1])});
y->Resize({static_cast<int>(input[1][0]),
static_cast<int>(input[1][2]),
static_cast<int>(input[1][3]),
static_cast<int>(input[1][1])});
LaunchOp(op, {x_var_name, y_var_name}, {out_var_name});
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
Tensor output_trans;
output_trans.Resize(out->dims());
transpose(out_data,
output_trans.mutable_data<float>(),
{static_cast<int>(os[0]),
static_cast<int>(os[2]),
static_cast<int>(os[3]),
static_cast<int>(os[1])},
{0, 3, 1, 2});
out_data = output_trans.mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], out_ref_data[i], 5e-4);
}
}
TEST(MLUBridges, concat) {
test_concat({{3, 3, 5, 2}, {2, 3, 5, 2}}, 0);
test_concat({{3, 5, 5, 2}, {3, 1, 5, 2}}, 1);
test_concat({{3, 3, 2, 2}, {3, 3, 4, 2}}, 2);
test_concat({{3, 3, 5, 2}, {3, 3, 5, 6}}, 3);
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(concat, kMLU);
......@@ -119,14 +119,6 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
LOG(FATAL) << "UnSupported weight precision!";
}
cnmlConvOpParam_t conv_param;
CNML_CALL(cnmlCreateConvOpParam(&conv_param,
strides[0],
strides[1],
dilations[0],
dilations[1],
paddings[0] * 2,
paddings[2] * 2));
std::string bias_var_name;
std::shared_ptr<MLUTensor> bias_tensor;
if (HasInputArg(op_info, scope, "Bias")) {
......@@ -160,15 +152,75 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
graph->FPType());
graph->BindConstData(bias_var_name, bias);
}
cnmlBaseOp_t conv_op;
const auto input_scale = op_info->GetAttr<float>("input_scale");
CNML_CALL(cnmlCreateConvOpForward(
&conv_op,
conv_param,
graph->GetNode(input_var_name)->mlu_tensor(),
output_tensor->mlu_tensor(),
filter_tensor->mlu_tensor(),
bias_tensor ? bias_tensor->mlu_tensor() : nullptr));
bool use_first_conv = false;
if (lite::DeviceInfo::Global().UseFirstConv() && input_dims_nhwc[3] == 3) {
use_first_conv = true;
}
cnmlBaseOp_t conv_op;
if (use_first_conv) {
cnmlConvFirstOpParam_t conv_param;
CNML_CALL(cnmlCreateConvFirstOpParam_V2(&conv_param,
strides[0],
strides[1],
dilations[0],
dilations[1],
paddings[2],
paddings[2],
paddings[0],
paddings[0]));
const auto mean_tensor = graph->AddNode("first_conv_mean_tensor",
std::vector<int64_t>{3},
CNML_CONST,
CNML_CNHW,
graph->FPType());
const auto std_tensor = graph->AddNode("first_conv_std_tensor",
std::vector<int64_t>{3},
CNML_CONST,
CNML_CNHW,
graph->FPType());
graph->BindConstRawData("first_conv_mean_tensor",
lite::DeviceInfo::Global().MeanVec().data(),
3,
false);
graph->BindConstRawData("first_conv_std_tensor",
lite::DeviceInfo::Global().StdVec().data(),
3,
false);
graph->GetNode(input_var_name)->set_mlu_dtype(CNML_DATA_UINT8);
CNML_CALL(cnmlCreateConvFirstOpForward(
&conv_op,
conv_param,
graph->GetNode(input_var_name)->mlu_tensor(),
mean_tensor->mlu_tensor(),
output_tensor->mlu_tensor(),
filter_tensor->mlu_tensor(),
bias_tensor ? bias_tensor->mlu_tensor() : nullptr,
std_tensor->mlu_tensor()));
CNML_CALL(cnmlDestroyConvFirstOpParam(&conv_param));
} else {
cnmlConvOpParam_t conv_param;
CNML_CALL(cnmlCreateConvOpParam(&conv_param,
strides[0],
strides[1],
dilations[0],
dilations[1],
paddings[0] * 2,
paddings[2] * 2));
CNML_CALL(cnmlCreateConvOpForward(
&conv_op,
conv_param,
graph->GetNode(input_var_name)->mlu_tensor(),
output_tensor->mlu_tensor(),
filter_tensor->mlu_tensor(),
bias_tensor ? bias_tensor->mlu_tensor() : nullptr));
CNML_CALL(cnmlDestroyConvOpParam(&conv_param));
}
graph->SetComputingDataType(
conv_op, graph->GetNode(input_var_name)->mlu_tensor(), 1 / input_scale);
......@@ -183,7 +235,6 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
graph->BindConstData(filter_var_name, filter);
graph->FuseOp(conv_op);
CNML_CALL(cnmlDestroyConvOpParam(&conv_param));
return REBUILD_WHEN_SHAPE_CHANGED;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto out = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims();
CHECK_EQ(x_dims.size(), 4);
auto scale = op_info->GetAttr<float>("scale");
auto out_w = op_info->GetAttr<int>("out_w");
auto out_h = op_info->GetAttr<int>("out_h");
auto align_corners = op_info->GetAttr<bool>("align_corners");
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
auto in_h = x_dims[1];
auto in_w = x_dims[2];
// Priority: SizeTensor > OutSize > Scale > scale > out_h/out_w
if (HasInputArg(op_info, scope, "SizeTensor")) {
LOG(ERROR) << "Not support SizeTensor input now";
CHECK(0);
} else {
if (HasInputArg(op_info, scope, "Scale")) {
LOG(ERROR) << "Not support Scale input now";
CHECK(0);
}
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
}
if (HasInputArg(op_info, scope, "OutSize")) {
LOG(ERROR) << "Not support OutSize input now";
CHECK(0);
}
}
out->Resize({x_dims[0], out_h, out_w, x_dims[3]});
auto output_tensor = graph->AddNode(out_var_name,
out->dims().Vectorize(),
CNML_TENSOR,
CNML_NHWC,
graph->FPType());
cnmlBaseOp_t interp_op;
/* if (interp_method == "bilinear") { */
/* cnmlInterpOpParam_t interp_param; */
/* CNML_CALL(cnmlCreateInterpOpParam(&interp_param, out_w, out_h,
* align_corners)); */
/* CNML_CALL(cnmlCreateInterpOp(&interp_op, */
/* input_tensor->mlu_tensor(), */
/* output_tensor->mlu_tensor(), */
/* interp_param)); */
/* CNML_CALL(cnmlDestroyInterpOpParam(&interp_param)); */
/* } else if (interp_method == "nearest") { */
cnmlNearestNeighborOpParam_t nn_param;
CNML_CALL(cnmlCreateNearestNeighborOpParam(&nn_param, out_w, out_h));
CNML_CALL(cnmlSetNearestNeighborAlignCorner(&nn_param, align_corners));
CNML_CALL(cnmlCreateNearestNeighborOp(&interp_op,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor(),
nn_param));
CNML_CALL(cnmlDestroyNearestNeighborOpParam(&nn_param));
/* } else { */
/* LOG(WARNING) << "[MLU] Unsupported interpolate method: " <<
* interp_method; */
/* return FAILED; */
/* } */
graph->FuseOp(interp_op);
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(nearest_interp,
kMLU,
paddle::lite::subgraph::mlu::InterpolateConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/interpolate_op.h"
#include <gtest/gtest.h>
#include <string>
#include "lite/core/device_info.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
template <typename dtype>
void ResizeNearestAlign(const lite::Tensor* x,
lite::Tensor* out,
bool with_align) {
auto x_dims = x->dims();
int num = x_dims[0];
int channels = x_dims[1];
int hin = x_dims[2];
int win = x_dims[3];
int hout = out->dims()[2];
int wout = out->dims()[3];
dtype scale_w = (with_align) ? (static_cast<float>(win - 1) / (wout - 1))
: (static_cast<float>(win) / (wout));
dtype scale_h = (with_align) ? (static_cast<float>(hin - 1) / (hout - 1))
: (static_cast<float>(hin) / (hout));
const dtype* src = x->data<dtype>();
dtype* dst = out->mutable_data<dtype>();
int dst_stride_w = 1;
int dst_stride_h = wout;
int dst_stride_c = wout * hout;
int dst_stride_batch = wout * hout * channels;
int src_stride_w = 1;
int src_stride_h = win;
int src_stride_c = win * hin;
int src_stride_batch = win * hin * channels;
for (int n = 0; n < num; ++n) {
for (int c = 0; c < channels; ++c) {
int src_index = n * src_stride_batch + c * src_stride_c;
for (int h = 0; h < hout; ++h) {
for (int w = 0; w < wout; ++w) {
int fw = (with_align) ? static_cast<int>(scale_w * w + 0.5)
: static_cast<int>(scale_w * w);
fw = (fw < 0) ? 0 : fw;
int fh = (with_align) ? static_cast<int>(scale_h * h + 0.5)
: static_cast<int>(scale_h * h);
fh = (fh < 0) ? 0 : fh;
int w_start = static_cast<int>(fw);
int h_start = static_cast<int>(fh);
int dst_index = n * dst_stride_batch + c * dst_stride_c +
h * dst_stride_h + w * dst_stride_w;
dst[dst_index] =
src[src_index + w_start * src_stride_w + h_start * src_stride_h];
}
}
}
}
}
template <typename DType>
void BilinearInterpRef(const lite::Tensor* x,
lite::Tensor* out,
bool align_corners,
int align_mode) {
auto x_dims = x->dims();
int batch_size = x_dims[0];
int channel_size = x_dims[1];
auto x_h = x_dims[2];
auto x_w = x_dims[3];
CHECK_EQ(x_dims.size(), 4);
auto out_dims = out->dims();
int out_h = out_dims[2];
int out_w = out_dims[3];
// copy from x if no change
if (x_h == out_h && x_w == out_w) {
out->CopyDataFrom(*x);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(x_h - 1) / (out_h - 1)
: static_cast<float>(x_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(x_w - 1) / (out_w - 1)
: static_cast<float>(x_w) / out_w;
}
// naive bilinear interpolation
auto x_data = x->data<DType>();
auto out_data = out->mutable_data<DType>();
bool align_flag = (align_mode == 0 && !align_corners);
std::vector<int> vy_n, vy_s;
std::vector<float> vd_n, vd_s;
vy_n.reserve(out_h);
vy_s.reserve(out_h);
vd_n.reserve(out_h);
vd_s.reserve(out_h);
for (int k = 0; k < out_h; k++) {
int yn = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
yn = (yn > 0) ? yn : 0;
int ys = (yn + 1) < (x_h - 1) ? (yn + 1) : (x_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float dn = align_flag ? idx_src_y - yn : ratio_h * k - yn;
float ds = 1.f - dn;
{
vy_n[k] = yn;
vy_s[k] = ys;
vd_n[k] = dn;
vd_s[k] = ds;
}
}
std::vector<int> vx_w, vx_e;
std::vector<float> vd_w, vd_e;
vx_w.reserve(out_w);
vx_e.reserve(out_w);
vd_w.reserve(out_w);
vd_e.reserve(out_w);
for (int l = 0; l < out_w; l++) {
int xw = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
xw = (xw > 0) ? xw : 0;
int xe = (xw + 1) < (x_w - 1) ? (xw + 1) : (x_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float dw = align_flag ? idx_src_x - xw : ratio_w * l - xw;
float de = 1.f - dw;
{
vx_w[l] = xw;
vx_e[l] = xe;
vd_w[l] = dw;
vd_e[l] = de;
}
}
std::vector<int64_t> x_strides(x_dims.size(), 1);
for (int idx = x_strides.size() - 2; idx >= 0; idx--) {
x_strides[idx] = x_strides[idx + 1] * x_dims[idx + 1];
}
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < channel_size; j++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
DType x0 = x_data[i * x_strides[0] + j * x_strides[1] +
vy_n[k] * x_strides[2] + vx_w[l] * x_strides[3]];
DType x1 = x_data[i * x_strides[0] + j * x_strides[1] +
vy_s[k] * x_strides[2] + vx_w[l] * x_strides[3]];
DType x2 = x_data[i * x_strides[0] + j * x_strides[1] +
vy_n[k] * x_strides[2] + vx_e[l] * x_strides[3]];
DType x3 = x_data[i * x_strides[0] + j * x_strides[1] +
vy_s[k] * x_strides[2] + vx_e[l] * x_strides[3]];
*out_data = x0 * vd_s[k] * vd_e[l] + x1 * vd_n[k] * vd_e[l] +
x2 * vd_s[k] * vd_w[l] + x3 * vd_n[k] * vd_w[l];
out_data++;
}
}
}
}
}
class InterpComputeTester {
protected:
// common attributes for this op.
std::string x_var_name = "X";
std::string outsize_var_name = "OutSize";
std::string out_var_name = "Out";
std::string out_ref_var_name = "out_ref";
DDim dims_{{1, 2, 3, 4}};
Scope scope;
std::string interp_method_ = "nearest";
float scale_ = -1.f;
int out_h_ = -1;
int out_w_ = -1;
bool align_corners_ = true;
int align_mode_ = 1;
bool use_outsize_ = false;
public:
InterpComputeTester(const std::string& alias,
DDim dims,
std::string interp_method = "nearest",
float scale = -1.f,
int out_h = -1,
int out_w = -1,
bool align_corners = true,
int align_mode = 1,
bool use_outsize = false)
: dims_(dims),
interp_method_(interp_method),
scale_(scale),
out_h_(out_h),
out_w_(out_w),
align_corners_(align_corners),
align_mode_(align_mode),
use_outsize_(use_outsize) {}
void Execute(float abs_error) {
cpp::OpDesc op_desc;
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* outsize = scope.Var(outsize_var_name)->GetMutable<Tensor>();
auto* outref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
int out_h = out_h_;
int out_w = out_w_;
if (scale_ > 0) {
out_h = static_cast<int>(dims_[2] * scale_);
out_w = static_cast<int>(dims_[3] * scale_);
}
x->Resize(dims_);
/* printf("----output tensor dims: %ld, %d, %d, %ld\n", dims_[0], out_h,
* out_w, dims_[1]); */
std::vector<int64_t> out_shape_nchw = {dims_[0], dims_[1], out_h, out_w};
out->Resize(DimNCHW2NHWC(out_shape_nchw));
outref->Resize(out_shape_nchw);
outsize->Resize({2});
FillTensor<float, float>(x, -1.f, 1.f);
if (use_outsize_) {
outsize->mutable_data<int>()[0] = out_h;
outsize->mutable_data<int>()[1] = out_w;
outsize->set_persistable(true);
}
if (interp_method_ == "nearest") {
op_desc.SetType("nearest_interp");
} else if (interp_method_ == "bilinear") {
op_desc.SetType("bilinear_interp");
} else {
LOG(FATAL) << "unsupport";
}
op_desc.SetInput("X", {x_var_name});
if (use_outsize_) {
op_desc.SetInput("OutSize", {outsize_var_name});
}
op_desc.SetOutput("Out", {out_var_name});
op_desc.SetAttr("scale", scale_);
op_desc.SetAttr("out_h", out_h_);
op_desc.SetAttr("out_w", out_w_);
op_desc.SetAttr("align_corners", align_corners_);
op_desc.SetAttr("align_mode", align_mode_);
op_desc.SetAttr("interp_method", interp_method_);
auto op = CreateOp<operators::InterpolateOp>(op_desc, &scope);
if (interp_method_ == "nearest") {
ResizeNearestAlign<float>(x, outref, align_corners_);
} else if (interp_method_ == "bilinear") {
BilinearInterpRef<float>(x, outref, align_corners_, align_mode_);
}
int in = dims_[0], ic = dims_[1], ih = dims_[2], iw = dims_[3];
Tensor input_trans;
input_trans.Resize(dims_);
transpose(x->mutable_data<float>(),
input_trans.mutable_data<float>(),
{in, ic, ih, iw},
{0, 2, 3, 1});
x->CopyDataFrom(input_trans);
x->Resize(DimNCHW2NHWC(dims_.Vectorize()));
if (use_outsize_) {
LaunchOp(op, {x_var_name, outsize_var_name}, {out_var_name});
} else {
LaunchOp(op, {x_var_name}, {out_var_name});
}
auto* out_ref_data = outref->mutable_data<float>();
Tensor output_trans;
output_trans.Resize(out_shape_nchw);
transpose(
out->mutable_data<float>(),
output_trans.mutable_data<float>(),
{static_cast<int>(dims_[0]), out_h, out_w, static_cast<int>(dims_[1])},
{0, 3, 1, 2});
auto* out_data = output_trans.mutable_data<float>();
for (int i = 0; i < out->dims().production(); ++i) {
EXPECT_NEAR(out_data[i], out_ref_data[i], abs_error);
}
}
};
void TestInterpOuthw(float abs_error = 2e-5) {
for (auto x_dims : std::vector<std::vector<int64_t>>{{3, 4, 8, 9}}) {
/* for (auto interp_method : std::vector<std::string>{"nearest",
* "bilinear"}) { */
for (auto interp_method : std::vector<std::string>{"nearest"}) {
for (int out_h : {6, 8, 12}) {
for (int out_w : {6, 9}) {
printf("testcase %s: out_w %d, out_h %d\n",
interp_method.c_str(),
out_w,
out_h);
InterpComputeTester tester(
"def", DDim(x_dims), interp_method, -1.f, out_h, out_w);
tester.Execute(abs_error);
}
}
}
}
}
void TestInterpScale(float abs_error = 2e-5) {
for (auto x_dims : std::vector<std::vector<int64_t>>{{3, 4, 8, 9}}) {
/* for (auto interp_method : std::vector<std::string>{"nearest",
* "bilinear"}) { */
for (auto interp_method : std::vector<std::string>{"nearest"}) {
for (float scale : {0.3f, 1.f, 1.7f}) {
printf("testcase %s: scale: %f\n", interp_method.c_str(), scale);
InterpComputeTester tester("def", DDim(x_dims), interp_method, scale);
tester.Execute(abs_error);
}
}
}
}
void TestInterpOutsize(float abs_error = 2e-5) {
for (auto x_dims : std::vector<std::vector<int64_t>>{{3, 4, 8, 9}}) {
/* for (auto interp_method : std::vector<std::string>{"nearest",
* "bilinear"}) { */
for (auto interp_method : std::vector<std::string>{"nearest"}) {
printf("testcase %s: outsize: %d %d\n", interp_method.c_str(), 4, 4);
InterpComputeTester tester(
"def", DDim(x_dims), interp_method, -1, 4, 4, true, 1, true);
tester.Execute(abs_error);
}
}
}
void TestInterpAlignCorners(float abs_error = 2e-5) {
for (auto x_dims : std::vector<std::vector<int64_t>>{{3, 4, 8, 9}}) {
for (bool align_corners : {true, false}) {
printf(
"testcase nearest: scale: 0.4, out_w -1 out_h -1, align_corners %d\n",
align_corners);
InterpComputeTester tester(
"def", DDim(x_dims), "nearest", 0.4, -1, -1, align_corners);
tester.Execute(abs_error);
}
}
}
void TestInterpAlignMode(float abs_error = 2e-5) {
for (auto x_dims : std::vector<std::vector<int64_t>>{{3, 4, 8, 9}}) {
for (bool align_corners : {true, false}) {
for (int align_mode : {0, 1}) {
printf(
"testcase bilinear: scale: 0.7, out_w -1 out_h -1, align_corners "
"%d, mode %d\n",
align_corners,
align_mode);
InterpComputeTester tester("def",
DDim(x_dims),
"bilinear",
0.7,
-1,
-1,
align_corners,
align_mode);
tester.Execute(abs_error);
}
}
}
}
TEST(MLUBridges, interpolate) {
float abs_error = 2e-5;
TestInterpOuthw(abs_error);
TestInterpScale(abs_error);
// bug, not usable
// TestInterpOutsize(abs_error);
TestInterpAlignCorners(abs_error);
// only for bilinear interp
// TestInterpAlignMode(abs_error);
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(nearest_interp, kMLU);
......@@ -22,3 +22,9 @@ USE_SUBGRAPH_BRIDGE(pool2d, kMLU);
USE_SUBGRAPH_BRIDGE(softmax, kMLU);
USE_SUBGRAPH_BRIDGE(batch_norm, kMLU);
USE_SUBGRAPH_BRIDGE(fc, kMLU);
USE_SUBGRAPH_BRIDGE(nearest_interp, kMLU);
USE_SUBGRAPH_BRIDGE(leaky_relu, kMLU);
USE_SUBGRAPH_BRIDGE(transpose, kMLU);
USE_SUBGRAPH_BRIDGE(transpose2, kMLU);
USE_SUBGRAPH_BRIDGE(concat, kMLU);
USE_SUBGRAPH_BRIDGE(scale, kMLU);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Create act node and set params from op
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
auto bias_after_scale = op_info->GetAttr<bool>("bias_after_scale");
auto scale = op_info->GetAttr<float>("scale");
auto bias = op_info->GetAttr<float>("bias");
auto beta = bias_after_scale ? bias : bias * scale;
std::vector<int64_t> shape = {1, 1, 1, 1};
std::string prefix = string_format("_%p", op);
auto alpha_tensor = graph->AddNode(
"Alpha" + prefix, shape, CNML_CONST, CNML_NHWC, graph->FPType());
auto beta_tensor = graph->AddNode(
"Beta" + prefix, shape, CNML_CONST, CNML_NHWC, graph->FPType());
graph->BindConstRawData("Alpha" + prefix, &scale, 1);
graph->BindConstRawData("Beta" + prefix, &beta, 1);
auto input_tensor = graph->GetNode(x_var_name);
cnmlBaseOp_t scale_op;
CNML_CALL(cnmlCreateScaleOp(&scale_op,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor(),
alpha_tensor->mlu_tensor(),
beta_tensor->mlu_tensor()));
graph->FuseOp(scale_op);
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(scale,
kMLU,
paddle::lite::subgraph::mlu::ScaleConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/scale_op.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
void scale_ref(const std::shared_ptr<operators::ScaleOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
float scale = op_info->GetAttr<float>("scale");
float bias = op_info->GetAttr<float>("bias");
bool bias_after_scale = op_info->GetAttr<bool>("bias_after_scale");
if (!bias_after_scale) {
bias *= scale;
}
auto x_data = x->data<float>();
auto out_data = out->mutable_data<float>();
DDim x_dims = x->dims();
DDim out_dims = out->dims();
CHECK_EQ(x_dims.production(), out_dims.production());
for (int i = 0; i < out_dims.production(); i++) {
out_data[i] = x_data[i] * scale + bias;
}
}
void test_scale(int bs,
int ic,
int ih,
int iw,
bool bias_after_scale,
float scale,
float bias) {
// prepare input&output variables
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
std::string out_ref_var_name("out_ref");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float, int>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("scale");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("bias_after_scale", bias_after_scale);
opdesc.SetAttr("scale", scale);
opdesc.SetAttr("bias", bias);
// create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::ScaleOp>(opdesc, &scope);
scale_ref(op);
out_ref->CopyDataFrom(*out);
Tensor input_trans;
input_trans.Resize({bs, ic, ih, iw});
transpose(x->mutable_data<float>(),
input_trans.mutable_data<float>(),
{bs, ic, ih, iw},
{0, 2, 3, 1});
auto os = out->dims();
out->Resize({static_cast<int>(os[0]),
static_cast<int>(os[2]),
static_cast<int>(os[3]),
static_cast<int>(os[1])});
x->CopyDataFrom(input_trans);
x->Resize({bs, ih, iw, ic});
LaunchOp(op, {x_var_name}, {out_var_name});
// execute reference implementation and save to output tensor('out')
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
Tensor output_trans;
output_trans.Resize(os);
transpose(out_data,
output_trans.mutable_data<float>(),
{static_cast<int>(os[0]),
static_cast<int>(os[2]),
static_cast<int>(os[3]),
static_cast<int>(os[1])},
{0, 3, 1, 2});
out_data = output_trans.mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
TEST(MLUBridges, scale) {
for (auto bs : {1, 3}) {
for (auto ic : {1, 3}) {
for (auto ih : {3, 4}) {
for (auto iw : {4, 3}) {
for (auto bias_after_scale : {false, true}) {
for (auto scale : {-1.0f, 5.0f}) {
for (auto bias : {-2.0f, 30.0f}) {
VLOG(3) << "bs: " << bs << " ic: " << ic << " ih: " << ih
<< " iw: " << iw
// << " bias_after_scale: " << bias_after_scale
<< " scale: " << scale << " bias: " << bias;
test_scale(bs, ic, ih, iw, bias_after_scale, scale, bias);
}
}
}
}
}
}
}
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(scale, kMLU);
......@@ -47,6 +47,8 @@ class MLUTensor {
return mlu_ptr_;
}
void set_mlu_dtype(cnmlDataType_t type) { mlu_dtype_ = type; }
~MLUTensor();
private:
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
std::vector<int> axis_to_nhwc4d(const std::vector<int>& axis) {
CHECK_EQ(axis.size(), 4);
std::vector<int> new_axis(4, 0);
const std::vector<int> axis_map1 = {0, 2, 3, 1};
const std::vector<int> axis_map2 = {0, 3, 1, 2};
for (size_t i = 0; i < new_axis.size(); ++i) {
new_axis[i] = axis_map2[axis[axis_map1[i]]];
}
return new_axis;
}
std::vector<int> axis_to_nhw3d(const std::vector<int>& axis) {
CHECK_EQ(axis.size(), 3);
std::vector<int> new_axis(3, 0);
const std::vector<int> axis_map = {0, 2, 1};
for (size_t i = 0; i < new_axis.size(); ++i) {
new_axis[i] = axis_map[axis[axis_map[i]]];
}
new_axis.push_back(3);
return new_axis;
}
std::vector<int64_t> infer_shape(const std::vector<int64_t>& x_dims,
const std::vector<int>& axis_nhwc) {
std::vector<int64_t> out_dims(x_dims);
for (size_t i = 0; i < out_dims.size(); ++i) {
out_dims[i] = x_dims[axis_nhwc[i]];
}
return out_dims;
}
int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Get input vars and op attributes
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims().Vectorize();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
auto axis = op_info->GetAttr<std::vector<int>>("axis");
std::vector<int> axis_nhwc;
if (axis.size() == 4) {
axis_nhwc = axis_to_nhwc4d(axis);
} else if (axis.size() == 3) {
axis_nhwc = axis_to_nhw3d(axis);
} else {
CHECK(0) << "Unsupport dim in mlu transpose";
}
auto output_dims_nhwc = infer_shape(x_dims, axis_nhwc);
output->Resize(output_dims_nhwc);
auto output_tensor = graph->AddNode(
out_var_name, output_dims_nhwc, CNML_TENSOR, CNML_NHWC, graph->FPType());
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
cnmlBaseOp_t transpose_op_{nullptr};
cnmlNdTransposeOpParam_t transpose_param{nullptr};
CNML_CALL(cnmlCreateNdTransposeOpParam(
&transpose_param, axis_nhwc.data(), axis_nhwc.size()));
// Use cnmlCreatexxxOpForward to create op.
CNML_CALL(cnmlCreateNdTransposeProOp(&transpose_op_,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor(),
transpose_param));
graph->FuseOp(transpose_op_);
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(transpose,
kMLU,
paddle::lite::subgraph::mlu::TransposeConverter);
REGISTER_SUBGRAPH_BRIDGE(transpose2,
kMLU,
paddle::lite::subgraph::mlu::TransposeConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/transpose_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int data_index(std::vector<int> pos, DDimLite dims) {
int d1 = dims[1];
int d2 = dims[2];
int d3 = dims[3];
return pos[3] + pos[2] * d3 + pos[1] * d3 * d2 + pos[0] * d3 * d2 * d1;
}
std::vector<int> pos_trans(std::vector<int> in_pos, std::vector<int> axis) {
std::vector<int> out_pos(in_pos.size());
for (int i = 0; i < axis.size(); i++) {
out_pos[axis[i]] = in_pos[i];
}
return out_pos;
}
template <typename dtype>
void transpose_ref(const std::shared_ptr<operators::TransposeOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto input =
scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto output =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_dims = input->dims();
auto y_dims = output->dims();
auto axis = op_info->GetAttr<std::vector<int>>("axis");
// auto input_data = input->data<dtype>();
auto* input_data = input->mutable_data<dtype>();
auto* output_data = output->mutable_data<dtype>();
int input_n = x_dims[0];
int input_c = x_dims[1];
int input_h = x_dims[2];
int input_w = x_dims[3];
for (int n = 0; n < input_n; ++n) {
for (int c = 0; c < input_c; ++c) {
for (int h = 0; h < input_h; ++h) {
for (int w = 0; w < input_w; ++w) {
std::vector<int> in_pos{n, c, h, w};
std::vector<int> out_pos = pos_trans(in_pos, axis);
int in_index = data_index(in_pos, x_dims);
int out_index = data_index(out_pos, y_dims);
output_data[out_index] = input_data[in_index];
}
}
}
}
}
void test_transpose(const std::vector<int64_t>& input_shape,
std::vector<int> axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize(input_shape);
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("transpose");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", axis);
// create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::TransposeOp>(opdesc, &scope);
// transpose_ref must run befor LaunchOp
// otherwise get Cannot access memory
// execute reference implementation and save to output tensor
transpose_ref<float>(op);
out_ref->CopyDataFrom(*out);
LaunchOp(op, {x_var_name}, {out_var_name});
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(MLUBridges, transpose) {
std::vector<int64_t> input_shape = {2, 3, 4, 5};
test_transpose(input_shape, std::vector<int>{0, 1, 3, 2});
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(transpose, kMLU);
USE_SUBGRAPH_BRIDGE(transpose2, kMLU);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.ddNod
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/layout_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace mlu {} // namespace mlu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFloat,
kNHWC,
paddle::lite::kernels::mlu::LayoutNhwcToNchwCompute<PRECISION(kFloat)>,
def_layout_nhwc2nchw_fp32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFP16,
kNHWC,
paddle::lite::kernels::mlu::LayoutNhwcToNchwCompute<PRECISION(kFP16)>,
def_layout_nhwc2nchw_fp16)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFloat,
kNHWC,
paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute<PRECISION(kFloat)>,
def_layout_nchw2nhwc_fp32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
layout,
kMLU,
kFP16,
kNHWC,
paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute<PRECISION(kFP16)>,
def_layout_nchw2nhwc_fp16)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <string>
#include <vector>
#include "lite/backends/x86/math/math_function.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/layout_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace mlu {
template <lite::TargetType Target, typename T>
inline void LayoutTransCompute(const int dim,
const lite::Context<Target>& context,
const lite::Tensor& in,
lite::Tensor* out,
const std::vector<int>& axis) {
switch (dim) {
case 2:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 2> trans2;
trans2(context, in, out, axis);
break;
case 3:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 3> trans3;
trans3(context, in, out, axis);
break;
case 4:
paddle::lite::x86::math::Transpose<lite::TargetType::kX86, T, 4> trans4;
trans4(context, in, out, axis);
break;
default:
CHECK(0) << ("Unsupport dim in mlu layout");
}
}
template <PrecisionType Precision>
class LayoutNchwToNhwcCompute
: public KernelLite<TARGET(kMLU), Precision, DATALAYOUT(kNHWC)> {
public:
using param_t = operators::LayoutParam;
void Run() override {
auto& param = this->template Param<param_t>();
auto* x = param.x;
auto* out = param.y;
out->template mutable_data<float>();
auto x_dims = param.x->dims().size();
auto& context = this->ctx_->template As<X86Context>();
std::vector<int> axis;
switch (x_dims) {
case 2:
axis = {0, 1};
break;
case 3:
axis = {0, 2, 1};
out->Resize(std::vector<int64_t>{
out->dims()[0], out->dims()[2], out->dims()[1]});
break;
case 4:
axis = {0, 2, 3, 1};
out->Resize(std::vector<int64_t>{
out->dims()[0], out->dims()[2], out->dims()[3], out->dims()[1]});
break;
default:
CHECK(0) << "Unsupport dim in mlu layout nchw to nhwc";
}
LayoutTransCompute<lite::TargetType::kX86, float>(
x_dims, context, *x, out, axis);
}
std::string doc() const override {
return "Mlu layout transform nchw to nhwc";
}
};
template <PrecisionType Precision>
class LayoutNhwcToNchwCompute
: public KernelLite<TARGET(kMLU), Precision, DATALAYOUT(kNHWC)> {
public:
using param_t = operators::LayoutParam;
void Run() override {
auto& param = this->template Param<param_t>();
auto* x = param.x;
auto* out = param.y;
out->template mutable_data<float>();
auto x_dims = param.x->dims().size();
auto& context = this->ctx_->template As<X86Context>();
std::vector<int> axis;
switch (x_dims) {
case 2:
axis = {0, 1};
break;
case 3:
axis = {0, 2, 1};
out->Resize(std::vector<int64_t>{
out->dims()[0], out->dims()[2], out->dims()[1]});
break;
case 4:
axis = {0, 3, 1, 2};
out->Resize(std::vector<int64_t>{
out->dims()[0], out->dims()[3], out->dims()[1], out->dims()[2]});
break;
default:
CHECK(0) << "Unsupport dim in mlu layout nhwc to nchw";
}
LayoutTransCompute<lite::TargetType::kX86, float>(
x_dims, context, *x, out, axis);
}
std::string doc() const override {
return "Mlu layout transform nhwc to nchw";
}
};
} // namespace mlu
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -97,9 +97,11 @@ class SubgraphEngine : public subgraph::Engine {
for (auto& inst : origin_program_) {
auto op = inst.op();
CHECK(op);
op->CheckShape();
op->InferShape();
std::string op_type = op->op_info()->Type();
op->CheckShape();
if (op_type != "concat") {
op->InferShape();
}
if (!bridges.Exists(op_type, TARGET(kMLU))) {
LOG(INFO) << "MLU bridges doesn't support op_type: " << op_type;
return subgraph::FAILED;
......
......@@ -63,6 +63,8 @@ add_kernel(sequence_topk_avg_pooling_compute_x86 X86 basic SRCS sequence_topk_av
add_kernel(search_fc_compute_x86 X86 basic SRCS search_fc_compute.cc DEPS ${lite_kernel_deps} search_fc)
add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} blas)
add_kernel(yolo_box_compute_x86 X86 basic SRCS yolo_box_compute.cc DEPS ${lite_kernel_deps})
add_kernel(interpolate_compute_x86 X86 basic SRCS interpolate_compute.cc DEPS ${lite_kernel_deps})
lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
......@@ -101,3 +103,7 @@ lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS
#lite_cc_test(test_attention_padding_mask_compute_x86 SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_x86)
lite_cc_test(test_sequence_arithmetic_compute_x86 SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_x86)
lite_cc_test(test_leaky_relu_compute_x86 SRCS leaky_relu_compute_test.cc DEPS activation_compute_x86)
lite_cc_test(test_yolo_box_compute_x86 SRCS yolo_box_compute_test.cc DEPS
yolo_box_compute_x86)
lite_cc_test(test_nearest_interp_comute_x86 SRCS interpolate_compute_test.cc
DEPS interpolate_compute_x86)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/interpolate_compute.h"
REGISTER_LITE_KERNEL(nearest_interp,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::InterpolateCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("OutSize",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindInput("SizeTensor",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/interpolate_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
inline void nearest_interp(const float* src,
int w_in,
int h_in,
float* dst,
int w_out,
int h_out,
bool with_align) {
float scale_w_new = (with_align)
? (static_cast<float>(w_in - 1) / (w_out - 1))
: (static_cast<float>(w_in) / (w_out));
float scale_h_new = (with_align)
? (static_cast<float>(h_in - 1) / (h_out - 1))
: (static_cast<float>(h_in) / (h_out));
if (with_align) {
for (int h = 0; h < h_out; ++h) {
float* dst_p = dst + h * w_out;
int near_y = static_cast<int>(scale_h_new * h + 0.5);
for (int w = 0; w < w_out; ++w) {
int near_x = static_cast<int>(scale_w_new * w + 0.5);
*dst_p++ = src[near_y * w_in + near_x];
}
}
} else {
for (int h = 0; h < h_out; ++h) {
float* dst_p = dst + h * w_out;
int near_y = static_cast<int>(scale_h_new * h);
for (int w = 0; w < w_out; ++w) {
int near_x = static_cast<int>(scale_w_new * w);
*dst_p++ = src[near_y * w_in + near_x];
}
}
}
}
inline std::vector<int> get_new_shape(
std::vector<const lite::Tensor*> list_new_shape_tensor) {
std::vector<int> vec_new_shape;
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
return vec_new_shape;
}
class InterpolateCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::InterpolateParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
int in_h = param.X->dims()[2];
int in_w = param.X->dims()[3];
if (param.SizeTensor.size() > 0) {
auto new_size = get_new_shape(param.SizeTensor);
param.out_h = new_size[0];
param.out_w = new_size[1];
} else {
auto scale_tensor = param.Scale;
if (scale_tensor != nullptr) {
auto* scale_data = param.Scale->mutable_data<float>();
param.scale = scale_data[0];
}
if (param.scale > 0) {
param.out_h = static_cast<int>(in_h * param.scale);
param.out_w = static_cast<int>(in_w * param.scale);
}
if (param.OutSize != nullptr) {
auto* outsize_data = param.OutSize->mutable_data<float>();
param.out_h = outsize_data[0];
param.out_w = outsize_data[1];
}
}
int num_cout = param.X->dims()[0];
int c_cout = param.X->dims()[1];
param.Out->Resize({num_cout, c_cout, param.out_h, param.out_w});
float* dout = param.Out->mutable_data<float>();
const float* din = param.X->data<float>();
int out_num = param.Out->dims()[0];
int out_c = param.Out->dims()[1];
int count = out_num * out_c;
int out_h = param.Out->dims()[2];
int out_w = param.Out->dims()[3];
int spatial_in = in_h * in_w;
int spatial_out = out_h * out_w;
#pragma omp parallel for
for (int i = 0; i < count; ++i) {
nearest_interp(din + spatial_in * i,
in_w,
in_h,
dout + spatial_out * i,
out_w,
out_h,
param.align_corners);
}
}
virtual ~InterpolateCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/interpolate_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
void NearestInterpRef(lite::Tensor* input,
lite::Tensor* output,
bool with_align) {
int hin = input->dims()[2];
int win = input->dims()[3];
int channels = input->dims()[1];
int num = input->dims()[0];
int hout = output->dims()[2];
int wout = output->dims()[3];
float scale_w = (with_align) ? (static_cast<float>(win - 1) / (wout - 1))
: (static_cast<float>(win) / (wout));
float scale_h = (with_align) ? (static_cast<float>(hin - 1) / (hout - 1))
: (static_cast<float>(hin) / (hout));
const float* src = input->data<float>();
float* dst = output->mutable_data<float>();
int dst_stride_w = 1;
int dst_stride_h = wout;
int dst_stride_c = wout * hout;
int dst_stride_batch = wout * hout * channels;
int src_stride_w = 1;
int src_stride_h = win;
int src_stride_c = win * hin;
int src_stride_batch = win * hin * channels;
for (int n = 0; n < num; ++n) {
for (int c = 0; c < channels; ++c) {
int src_index = n * src_stride_batch + c * src_stride_c;
for (int h = 0; h < hout; ++h) {
for (int w = 0; w < wout; ++w) {
int fw = (with_align) ? static_cast<int>(scale_w * w + 0.5)
: static_cast<int>(scale_w * w);
fw = (fw < 0) ? 0 : fw;
int fh = (with_align) ? static_cast<int>(scale_h * h + 0.5)
: static_cast<int>(scale_h * h);
fh = (fh < 0) ? 0 : fh;
int w_start = static_cast<int>(fw);
int h_start = static_cast<int>(fh);
int dst_index = n * dst_stride_batch + c * dst_stride_c +
h * dst_stride_h + w * dst_stride_w;
dst[dst_index] =
src[src_index + w_start * src_stride_w + h_start * src_stride_h];
}
}
}
}
}
TEST(interpolate_x86, retrive_op) {
auto interpolate =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"nearest_interp");
ASSERT_FALSE(interpolate.empty());
ASSERT_TRUE(interpolate.front());
}
TEST(interpolate_x86, init) {
InterpolateCompute interpolate;
ASSERT_EQ(interpolate.precision(), PRECISION(kFloat));
ASSERT_EQ(interpolate.target(), TARGET(kX86));
}
TEST(interpolate_x86, run_test) {
lite::Tensor X, OutSize, Out, Out_base;
operators::InterpolateParam param;
InterpolateCompute interpolate;
int n = 1, c = 3, in_h = 40, in_w = 40;
int out_h = 80, out_w = 80;
float scale = 2.0;
param.out_h = out_h;
param.out_w = out_w;
param.scale = scale;
param.align_corners = false;
X.Resize({n, c, in_h, in_w});
OutSize.Resize({2});
Out.Resize({n, c, out_h, out_w});
Out_base.Resize({n, c, out_h, out_w});
auto* out_data = Out.mutable_data<float>();
auto* out_base_data = Out_base.mutable_data<float>();
auto* x_data = X.mutable_data<float>();
auto* outsize_data = OutSize.mutable_data<float>();
for (int i = 0; i < X.dims().production(); i++) {
x_data[i] = i + 5.0;
}
outsize_data[0] = out_h;
outsize_data[1] = out_w;
param.X = &X;
param.OutSize = &OutSize;
param.Out = &Out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
interpolate.SetContext(std::move(ctx));
interpolate.SetParam(std::move(param));
interpolate.Run();
NearestInterpRef(&X, &Out_base, false);
for (int i = 0; i < Out.dims().production(); i++) {
LOG(INFO) << out_data[i];
EXPECT_NEAR(out_data[i], out_base_data[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(nearest_interp, kX86, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/yolo_box_compute.h"
REGISTER_LITE_KERNEL(yolo_box,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::YoloBoxCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("ImgSize",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Scores", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/operators/yolo_box_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
T sigmoid(T x) {
return 1.f / (1.f + expf(-x));
}
template <typename T>
void get_yolo_box(T* box,
const T* x,
const int* anchors,
int i,
int j,
int an_idx,
int grid_size,
int input_size,
int index,
int stride,
int img_height,
int img_width) {
box[0] = (i + sigmoid(x[index])) * img_height / grid_size;
box[1] = (j + sigmoid(x[index + stride])) * img_height / grid_size;
box[2] = std::exp(x[index + stride]) * anchors[2 * an_idx] * img_width /
input_size;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
img_height / input_size;
}
inline int get_entry_index(int batch,
int an_idx,
int hw_idx,
int an_num,
int an_stride,
int stride,
int entry) {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
template <typename T>
void calc_detection_box(T* boxes,
T* box,
const int box_idx,
const int img_height,
const int img_width) {
boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2;
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<float>(0);
boxes[box_idx + 1] =
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<float>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
? boxes[box_idx + 2]
: static_cast<float>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
? boxes[box_idx + 3]
: static_cast<float>(img_height - 1);
}
template <typename T>
void calc_label_score(T* scores,
const T* input,
const int label_idx,
const int score_idx,
const int class_num,
const T conf,
const int stride) {
for (int i = 0; i < class_num; i++) {
scores[score_idx + i] = conf * sigmoid(input[label_idx + i * stride]);
}
}
class YoloBoxCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::YoloBoxParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
const int n = param.X->dims()[0];
const int h = param.X->dims()[2];
const int w = param.X->dims()[3];
const int b_num = param.Boxes->dims()[1];
const int an_num = param.anchors.size() / 2;
int X_size = param.downsample_ratio * h;
const int stride = h * w;
const int an_stride = (param.class_num + 5) * stride;
auto anchors_data = param.anchors.data();
const float* X_data = param.X->data<float>();
int* ImgSize_data = param.ImgSize->mutable_data<int>();
float* Boxes_data = param.Boxes->mutable_data<float>();
float* Scores_data = param.Scores->mutable_data<float>();
float box[4];
for (int i = 0; i < n; i++) {
int img_height = ImgSize_data[2 * i];
int img_width = ImgSize_data[2 * i + 1];
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
int obj_idx =
get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 4);
float conf = sigmoid(X_data[obj_idx]);
if (conf < param.conf_thresh) {
continue;
}
int box_idx =
get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 0);
get_yolo_box(box,
X_data,
anchors_data,
l,
k,
j,
h,
X_size,
box_idx,
stride,
img_height,
img_width);
box_idx = (i * b_num + j * stride + k * w + l) * 4;
calc_detection_box(Boxes_data, box, box_idx, img_height, img_width);
int label_idx =
get_entry_index(i, j, k * w + l, an_num, an_stride, stride, 5);
int score_idx =
(i * b_num + j * stride + k * w + l) * param.class_num;
calc_label_score(Scores_data,
X_data,
label_idx,
score_idx,
param.class_num,
conf,
stride);
}
}
}
}
}
virtual ~YoloBoxCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/yolo_box_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
namespace test {
float sigmoid_base(float x) { return 1.f / (1.f + expf(-x)); }
void get_yolo_box_base(float* box,
const float* x,
const int* anchors,
int i,
int j,
int an_idx,
int grid_size,
int input_size,
int index,
int stride,
int img_height,
int img_width) {
box[0] = (i + sigmoid_base(x[index])) * img_width / grid_size;
box[1] = (j + sigmoid_base(x[index + stride])) * img_height / grid_size;
box[2] = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size;
box[3] = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] *
img_height / input_size;
}
int get_entry_index_base(int batch,
int an_idx,
int hw_idx,
int an_num,
int an_stride,
int stride,
int entry) {
return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx;
}
void calc_detection_box_base(float* boxes,
float* box,
const int box_idx,
const int img_height,
const int img_width) {
boxes[box_idx] = box[0] - box[2] / 2;
boxes[box_idx + 1] = box[1] - box[3] / 2;
boxes[box_idx + 2] = box[0] + box[2] / 2;
boxes[box_idx + 3] = box[1] + box[3] / 2;
boxes[box_idx] = boxes[box_idx] > 0 ? boxes[box_idx] : static_cast<float>(0);
boxes[box_idx + 1] =
boxes[box_idx + 1] > 0 ? boxes[box_idx + 1] : static_cast<float>(0);
boxes[box_idx + 2] = boxes[box_idx + 2] < img_width - 1
? boxes[box_idx + 2]
: static_cast<float>(img_width - 1);
boxes[box_idx + 3] = boxes[box_idx + 3] < img_height - 1
? boxes[box_idx + 3]
: static_cast<float>(img_height - 1);
}
void calc_label_score_base(float* scores,
const float* input,
const int label_idx,
const int score_idx,
const int class_num,
const float conf,
const int stride) {
for (int i = 0; i < class_num; i++) {
scores[score_idx + i] = conf * sigmoid_base(input[label_idx + i * stride]);
}
}
void RunBaseline(const lite::Tensor* X,
const lite::Tensor* ImgSize,
lite::Tensor* Boxes,
lite::Tensor* Scores,
int class_num,
float conf_thresh,
int downsample_ratio,
std::vector<int> anchors) {
auto* in = X;
auto* imgsize = ImgSize;
const int n = in->dims()[0];
const int h = in->dims()[2];
const int w = in->dims()[3];
const int an_num = anchors.size() / 2;
int in_size = downsample_ratio * h;
int box_num = in->dims()[2] * in->dims()[3] * an_num;
Boxes->Resize({in->dims()[0], box_num, 4});
Scores->Resize({in->dims()[0], box_num, class_num});
auto* boxes = Boxes;
auto* scores = Scores;
const int b_num = boxes->dims()[0];
const int stride = h * w;
const int an_stride = (class_num + 5) * stride;
auto anchors_data = anchors.data();
const float* in_data = in->data<float>();
const int* imgsize_data = imgsize->data<int>();
float* boxes_data = boxes->mutable_data<float>();
float* scores_data = scores->mutable_data<float>();
float box[4];
for (int i = 0; i < n; i++) {
int img_height = imgsize_data[2 * i];
int img_width = imgsize_data[2 * i + 1];
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
int obj_idx = test::get_entry_index_base(
i, j, k * w + l, an_num, an_stride, stride, 4);
float conf = test::sigmoid_base(in_data[obj_idx]);
if (conf < conf_thresh) {
continue;
}
int box_idx = test::get_entry_index_base(
i, j, k * w + l, an_num, an_stride, stride, 0);
test::get_yolo_box_base(box,
in_data,
anchors_data,
l,
k,
j,
h,
in_size,
box_idx,
stride,
img_height,
img_width);
box_idx = (i * b_num + j * stride + k * w + l) * 4;
test::calc_detection_box_base(
boxes_data, box, box_idx, img_height, img_width);
int label_idx = test::get_entry_index_base(
i, j, k * w + l, an_num, an_stride, stride, 5);
int score_idx = (i * b_num + j * stride + k * w + l) * class_num;
test::calc_label_score_base(scores_data,
in_data,
label_idx,
score_idx,
class_num,
conf,
stride);
}
}
}
}
}
} // namespace test
TEST(yolo_box_x86, retrive_op) {
auto yolo_box =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"yolo_box");
ASSERT_FALSE(yolo_box.empty());
ASSERT_TRUE(yolo_box.front());
}
TEST(yolo_box_x86, init) {
YoloBoxCompute<float> yolo_box;
ASSERT_EQ(yolo_box.precision(), PRECISION(kFloat));
ASSERT_EQ(yolo_box.target(), TARGET(kX86));
}
TEST(yolo_box_x86, run_test) {
lite::Tensor X, ImgSize, Boxes, Scores, Boxes_base, Scores_base;
YoloBoxCompute<float> yolo_box;
operators::YoloBoxParam param;
int s = 3, cls = 4;
int n = 1, c = s * (5 + cls), h = 16, w = 16;
param.anchors = {2, 3, 4, 5, 8, 10};
param.downsample_ratio = 2;
param.conf_thresh = 0.5;
param.class_num = cls;
int m = h * w * param.anchors.size() / 2;
X.Resize({n, c, h, w});
ImgSize.Resize({1, 2});
Boxes.Resize({n, m, 4});
Boxes_base.Resize({n, m, 4});
Scores.Resize({n, cls, m});
Scores_base.Resize({n, cls, m});
auto x_data = X.mutable_data<float>();
auto imgsize_data = ImgSize.mutable_data<float>();
auto boxes_data = Boxes.mutable_data<float>();
auto scores_data = Scores.mutable_data<float>();
auto boxes_base_data = Boxes_base.mutable_data<float>();
auto scores_base_data = Scores_base.mutable_data<float>();
for (int i = 0; i < X.dims().production(); i++) {
x_data[i] = static_cast<float>(i);
}
for (int i = 0; i < ImgSize.dims().production(); i++) {
imgsize_data[i] = static_cast<float>(i);
}
test::RunBaseline(&X,
&ImgSize,
&Boxes_base,
&Scores_base,
param.class_num,
param.conf_thresh,
param.downsample_ratio,
param.anchors);
param.X = &X;
param.ImgSize = &ImgSize;
param.Boxes = &Boxes;
param.Scores = &Scores;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
yolo_box.SetContext(std::move(ctx));
yolo_box.SetParam(std::move(param));
yolo_box.Run();
for (int i = 0; i < Boxes.dims().production(); i++) {
EXPECT_NEAR(boxes_data[i], boxes_base_data[i], 1e-5);
}
for (int i = 0; i < Scores.dims().production(); i++) {
EXPECT_NEAR(scores_data[i], scores_base_data[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(yolo_box, kX86, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册