提交 26ba385d 编写于 作者: K kechxu 提交者: Kecheng Xu

Prediction: online integrate lane scanning model

上级 a3507e90
......@@ -140,6 +140,9 @@ DEFINE_string(torch_vehicle_cruise_go_file,
DEFINE_string(torch_vehicle_cruise_cutin_file,
"/apollo/modules/prediction/data/cruise_cutin_vehicle_model.pt",
"Vehicle cruise go model file");
DEFINE_string(torch_vehicle_lane_scanning_file,
"/apollo/modules/prediction/data/lane_scanning_vehicle_model.pt",
"Vehicle lane scanning model file");
DEFINE_int32(max_num_obstacles, 300,
"maximal number of obstacles stored in obstacles container.");
DEFINE_double(valid_position_diff_threshold, 0.5,
......
......@@ -88,6 +88,7 @@ DECLARE_string(evaluator_vehicle_mlp_file);
DECLARE_string(torch_vehicle_junction_mlp_file);
DECLARE_string(torch_vehicle_cruise_go_file);
DECLARE_string(torch_vehicle_cruise_cutin_file);
DECLARE_string(torch_vehicle_lane_scanning_file);
DECLARE_string(evaluator_vehicle_rnn_file);
DECLARE_string(evaluator_vehicle_cruise_mlp_file);
DECLARE_int32(max_num_obstacles);
......
......@@ -191,7 +191,14 @@ cc_library(
deps = [
"//modules/prediction/container:container_manager",
"//modules/prediction/evaluator",
],
] + select({
"//tools/platforms:use_gpu": [
"@pytorch",
],
"//conditions:default": [
"@pytorch",
],
}),
)
cpplint()
......@@ -19,6 +19,7 @@
#include <utility>
#include "cyber/common/file.h"
#include "modules/common/proto/pnc_point.pb.h"
#include "modules/prediction/common/feature_output.h"
#include "modules/prediction/common/prediction_gflags.h"
#include "modules/prediction/common/prediction_system_gflags.h"
......@@ -30,9 +31,12 @@ namespace apollo {
namespace prediction {
using apollo::common::adapter::AdapterConfig;
using apollo::common::TrajectoryPoint;
using apollo::cyber::common::GetProtoFromFile;
LaneScanningEvaluator::LaneScanningEvaluator() {}
LaneScanningEvaluator::LaneScanningEvaluator() : device_(torch::kCPU) {
LoadModel();
}
void LaneScanningEvaluator::Evaluate(Obstacle* obstacle_ptr) {
std::vector<Obstacle*> dummy_dynamic_env;
......@@ -77,7 +81,16 @@ void LaneScanningEvaluator::Evaluate(Obstacle* obstacle_ptr,
ADEBUG << "Save extracted features for learning locally.";
return;
}
// TODO(jiacheng): once the model is trained, implement this online part.
std::vector<torch::jit::IValue> torch_inputs;
torch::Tensor torch_input =
torch::zeros({1, static_cast<int>(feature_values.size())});
for (size_t i = 0; i < feature_values.size(); ++i) {
torch_input[0][i] = static_cast<float>(feature_values[i]);
}
torch_inputs.push_back(std::move(torch_input));
ModelInference(torch_inputs, torch_lane_scanning_model_ptr_,
latest_feature_ptr);
}
bool LaneScanningEvaluator::ExtractFeatures(
......@@ -294,8 +307,42 @@ bool LaneScanningEvaluator::ExtractStaticEnvFeatures(
}
}
size_t max_feature_size = LANE_POINTS_SIZE * SINGLE_LANE_FEATURE_SIZE *
MAX_NUM_LANE;
while (feature_values->size() < max_feature_size) {
feature_values->push_back(0.0);
}
return true;
}
void LaneScanningEvaluator::LoadModel() {
// TODO(all) uncomment the following when cuda issue is resolved
// if (torch::cuda::is_available()) {
// ADEBUG << "CUDA is available";
// device_ = torch::Device(torch::kCUDA);
// }
torch::set_num_threads(1);
torch_lane_scanning_model_ptr_ =
torch::jit::load(FLAGS_torch_vehicle_lane_scanning_file, device_);
}
void LaneScanningEvaluator::ModelInference(
const std::vector<torch::jit::IValue>& torch_inputs,
std::shared_ptr<torch::jit::script::Module> torch_model_ptr,
Feature* feature_ptr) {
auto torch_output_tensor = torch_model_ptr->forward(torch_inputs).toTensor();
auto torch_output = torch_output_tensor.accessor<float, 2>();
for (size_t i = 0; i < SHORT_TERM_TRAJECTORY_SIZE; ++i) {
TrajectoryPoint point;
double x = static_cast<double>(torch_output[0][2 * i]);
double y = static_cast<double>(torch_output[0][2 * i + 1]);
point.mutable_path_point()->set_x(x);
point.mutable_path_point()->set_y(y);
feature_ptr->add_short_term_predicted_trajectory_points()
->CopyFrom(point);
}
}
} // namespace prediction
} // namespace apollo
......@@ -20,6 +20,9 @@
#include <string>
#include <vector>
#include "torch/script.h"
#include "torch/torch.h"
#include "modules/prediction/evaluator/evaluator.h"
namespace apollo {
......@@ -67,6 +70,11 @@ class LaneScanningEvaluator : public Evaluator {
std::string GetName() override { return "LANE_SCANNING_EVALUATOR"; }
private:
/**
* @brief Load model from file
*/
void LoadModel();
/**
* @brief Extract the features for obstacles
* @param Obstacle pointer
......@@ -84,11 +92,22 @@ class LaneScanningEvaluator : public Evaluator {
const LaneGraph* lane_graph_ptr,
std::vector<double>* feature_values);
void ModelInference(
const std::vector<torch::jit::IValue>& torch_inputs,
std::shared_ptr<torch::jit::script::Module> torch_model_ptr,
Feature* feature_ptr);
private:
static const size_t OBSTACLE_FEATURE_SIZE = 5 * 9;
static const size_t INTERACTION_FEATURE_SIZE = 8;
static const size_t SINGLE_LANE_FEATURE_SIZE = 4;
static const size_t LANE_POINTS_SIZE = 100; // (100 * 0.2m = 20m)
static const size_t MAX_NUM_LANE = 10;
static const size_t SHORT_TERM_TRAJECTORY_SIZE = 10;
std::shared_ptr<torch::jit::script::Module>
torch_lane_scanning_model_ptr_ = nullptr;
torch::Device device_;
};
} // namespace prediction
......
......@@ -3,6 +3,7 @@ syntax = "proto2";
package apollo.prediction;
import "modules/common/proto/geometry.proto";
import "modules/common/proto/pnc_point.proto";
import "modules/perception/proto/perception_obstacle.proto";
import "modules/prediction/proto/lane_graph.proto";
import "modules/prediction/proto/prediction_point.proto";
......@@ -63,7 +64,7 @@ message ObstaclePriority {
optional Priority priority = 25 [default = NORMAL];
}
// next id = 32
// next id = 33
message Feature {
// Obstacle ID
optional int32 id = 1;
......@@ -107,6 +108,9 @@ message Feature {
// Obstacle ground-truth labels:
repeated PredictionTrajectoryPoint future_trajectory_points = 31;
// Obstacle short-term predicted trajectory points
repeated common.TrajectoryPoint short_term_predicted_trajectory_points = 32;
}
message ObstacleHistory {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册