提交 8ce0d74e 编写于 作者: D Dontmovedad 提交者: Qi Luo

Obstacel Detector: Fix inference input dimension

上级 d8926d46
......@@ -32,33 +32,35 @@ LibtorchObstacleDetection::LibtorchObstacleDetection() : device_(torch::kCPU) {
}
bool LibtorchObstacleDetection::Evaluate(
const std::vector<std::vector<double>>& imageFrame) {
const std::vector<std::vector<std::vector<double>>>& imageFrame) {
// Sanity checks.
omp_set_num_threads(1);
if (imageFrame.size() == 0) {
AINFO << "Got no channel in image frame!";
AERROR << "Got no channel in image frame!";
return false;
}
if (imageFrame[0].size() == 0) {
AINFO << "Got no image frame in channel 0!";
AERROR << "Got no image frame in channel 0!";
return false;
}
if (imageFrame[0].size() != 72000) {
AINFO << "imageFrame[0].size() = " << imageFrame[0].size() << ", skiping!";
if (imageFrame[0].size() != 369664) {
AERROR << "imageFrame[0].size() = " << imageFrame[0].size() << ", skiping!";
return false;
}
// image imput size is 1920 * 1080 = 2073600
torch::Tensor image_tensor = torch::empty(32 * 3 * 3 * 3);
// image imput size is 608 * 608 = 369664
torch::Tensor image_tensor = torch::empty(1 * 3 * 608 * 608);
float* data = image_tensor.data_ptr<float>();
for (const auto& channel : imageFrame) {
for (const auto& i : channel) {
*data++ = static_cast<float>(i) / 32767.0;
for (const auto& j : i) {
*data++ = static_cast<float>(j) / 255.0;
}
}
}
torch::Tensor torch_input = torch::from_blob(image_tensor.data_ptr<float>(),
{32, 3, 3, 3, 3});
{1, 3, 608, 608});
std::vector<torch::jit::IValue> torch_inputs;
torch_inputs.push_back(torch_input.to(device_));
......
......@@ -20,6 +20,10 @@
#include "torch/script.h"
#include "torch/torch.h"
#include "opencv2/core/core.hpp"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
namespace apollo {
namespace perception {
namespace inference {
......@@ -30,7 +34,7 @@ class LibtorchObstacleDetection {
~LibtorchObstacleDetection() = default;
bool Evaluate(const std::vector<std::vector<double>>& imageFrame);
bool Evaluate(const std::vector<std::vector<std::vector<double>>>& imageFrame);
private:
void LoadModel();
......
......@@ -32,8 +32,7 @@ class LibtorchObstacleDetectionTest : public ::testing::Test {
};
TEST_F(LibtorchObstacleDetectionTest, is_) {
std::vector<std::vector<double>> imageFrame(4,
(std::vector<double> (2073600, 0.01)));
std::vector<std::vector<std::vector<double>>>imageFrame(3, std::vector<std::vector<double>>(608, std::vector<double>(608, 0.01)));
bool result = obstacle_detection_.Evaluate(imageFrame);
EXPECT_EQ(result, false);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册