提交 4498313d 编写于 作者: Z zhupengyang 提交者: GitHub

[XPU] add yolo_box bridge (#3319)

上级 7603c043
......@@ -25,6 +25,7 @@ lite_cc_library(subgraph_bridge_layer_norm_op_xpu SRCS layer_norm_op.cc DEPS ${x
lite_cc_library(subgraph_bridge_dropout_op_xpu SRCS dropout_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_matmul_op_xpu SRCS matmul_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_cast_op_xpu SRCS cast_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_yolo_box_op_xpu SRCS yolo_box_op.cc DEPS ${xpu_subgraph_bridge_deps})
set(xpu_subgraph_bridges
subgraph_bridge_registry
......@@ -48,6 +49,7 @@ set(xpu_subgraph_bridges
subgraph_bridge_dropout_op_xpu
subgraph_bridge_matmul_op_xpu
subgraph_bridge_cast_op_xpu
subgraph_bridge_yolo_box_op_xpu
CACHE INTERNAL "xpu_subgraph_bridges")
message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}")
......@@ -37,3 +37,4 @@ USE_SUBGRAPH_BRIDGE(gelu, kXPU);
USE_SUBGRAPH_BRIDGE(dropout, kXPU);
USE_SUBGRAPH_BRIDGE(matmul, kXPU);
USE_SUBGRAPH_BRIDGE(cast, kXPU);
USE_SUBGRAPH_BRIDGE(yolo_box, kXPU);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {
int YoloBoxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindTensor(x_name);
auto img_size_name = op_info->Input("ImgSize").front();
auto img_size = scope->FindTensor(img_size_name);
auto boxes_name = op_info->Output("Boxes").front();
auto scores_name = op_info->Output("Scores").front();
auto anchors = op_info->GetAttr<std::vector<int>>("anchors");
auto class_num = op_info->GetAttr<int>("class_num");
auto conf_thresh = op_info->GetAttr<float>("conf_thresh");
auto downsample_ratio = op_info->GetAttr<int>("downsample_ratio");
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// ImgSize node
std::shared_ptr<Node> img_size_node = nullptr;
if (graph->Has(img_size_name)) {
img_size_node = graph->Get(img_size_name);
} else {
img_size_node = graph->Add(img_size_name, *img_size);
}
// Softmax node
auto yolo_box_data =
graph->builder_.CreateYoloBox(*x_node->data(),
*img_size_node->data(),
CvtShape<xtcl::Integer>(anchors),
class_num,
conf_thresh,
downsample_ratio);
graph->Add(boxes_name, graph->builder_.GetField(yolo_box_data, 0));
graph->Add(scores_name, graph->builder_.GetField(yolo_box_data, 1));
return SUCCESS;
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(yolo_box,
kXPU,
paddle::lite::subgraph::xpu::YoloBoxConverter);
......@@ -34,7 +34,7 @@ int SubgraphEngine::BuildDeviceProgram() {
subgraph::xpu::Graph graph;
const auto& bridges = subgraph::Registry::Instance();
for (auto& inst : origin_program_) {
auto op = inst.op();
auto op = const_cast<OpLite*>(inst.op());
CHECK(op);
op->CheckShape();
op->InferShape();
......@@ -43,10 +43,8 @@ int SubgraphEngine::BuildDeviceProgram() {
return subgraph::FAILED;
}
auto kernel = inst.kernel();
status |=
bridges.Select(op_type, TARGET(kXPU))(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
status |= bridges.Select(op_type, TARGET(kXPU))(
reinterpret_cast<void*>(&graph), op, const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
return subgraph::FAILED;
}
......
......@@ -228,14 +228,14 @@ class YoloBoxComputeTester : public arena::TestCase {
}
};
void test_yolobox(Place place) {
for (int class_num : {1, 2, 3, 4}) {
for (float conf_thresh : {0.01, 0.2, 0.7}) {
void TestYoloBox(Place place, float abs_error) {
for (int class_num : {1, 4}) {
for (float conf_thresh : {0.01, 0.2}) {
for (int downsample_ratio : {16, 32}) {
std::vector<int> anchor({10, 13, 16, 30});
std::vector<int> anchor{10, 13, 16, 30, 33, 30};
std::unique_ptr<arena::TestCase> tester(new YoloBoxComputeTester(
place, "def", anchor, class_num, conf_thresh, downsample_ratio));
arena::Arena arena(std::move(tester), place, 2e-5);
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
......@@ -243,13 +243,17 @@ void test_yolobox(Place place) {
}
TEST(YoloBox, precision) {
// #ifdef LITE_WITH_X86
// Place place(TARGET(kX86));
// #endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_yolobox(place);
float abs_error = 2e-5;
Place place;
#if defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
TestYoloBox(place, abs_error);
}
} // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册