提交 61f87fd2 编写于 作者: K kechxu 提交者: Calvin Miao

Prediction: fix bug in lane_sequence_predictor

上级 85afbaf2
......@@ -13,6 +13,7 @@ cc_library(
"//modules/common/proto:pnc_point_proto",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/common:prediction_util",
"//modules/prediction/common:prediction_map",
"//modules/prediction/predictor/sequence:sequence_predictor",
"//modules/prediction/proto:lane_graph_proto",
"@eigen//:eigen",
......
......@@ -18,10 +18,12 @@
#include <string>
#include <utility>
#include <memory>
#include "modules/common/log.h"
#include "modules/prediction/common/prediction_gflags.h"
#include "modules/prediction/common/prediction_util.h"
#include "modules/prediction/common/prediction_map.h"
namespace apollo {
namespace prediction {
......@@ -29,6 +31,7 @@ namespace prediction {
using apollo::common::PathPoint;
using apollo::common::TrajectoryPoint;
using apollo::common::math::KalmanFilter;
using apollo::hdmap::LaneInfo;
void LaneSequencePredictor::Predict(Obstacle* obstacle) {
Clear();
......@@ -88,9 +91,11 @@ void LaneSequencePredictor::Predict(Obstacle* obstacle) {
std::string curr_lane_id = sequence.lane_segment(0).lane_id();
std::vector<TrajectoryPoint> points;
DrawLaneSequenceTrajectoryPoints(obstacle->kf_lane_tracker(curr_lane_id),
sequence, FLAGS_prediction_duration,
FLAGS_prediction_freq, &points);
DrawLaneSequenceTrajectoryPoints(
feature, curr_lane_id,
obstacle->kf_lane_tracker(curr_lane_id),
sequence, FLAGS_prediction_duration,
FLAGS_prediction_freq, &points);
Trajectory trajectory = GenerateTrajectory(points);
trajectory.set_probability(sequence.probability());
......@@ -102,11 +107,25 @@ void LaneSequencePredictor::Predict(Obstacle* obstacle) {
}
void LaneSequencePredictor::DrawLaneSequenceTrajectoryPoints(
const Feature& feature, const std::string& lane_id,
const KalmanFilter<double, 4, 2, 0>& kf, const LaneSequence& sequence,
double total_time, double freq, std::vector<TrajectoryPoint>* points) {
// PredictionMap* map = PredictionMap::instance();
Eigen::Matrix<double, 4, 1> state(kf.GetStateEstimate());
if (!FLAGS_enable_kf_tracking) {
Eigen::Vector2d position(feature.position().x(),
feature.position().y());
PredictionMap* map = PredictionMap::instance();
std::shared_ptr<const LaneInfo> lane_info = map->LaneById(lane_id);
double lane_s = 0.0;
double lane_l = 0.0;
if (map->GetProjection(position, lane_info, &lane_s, &lane_l)) {
state(0, 0) = lane_s;
state(1, 0) = lane_l;
state(2, 0) = feature.speed();
state(3, 0) = feature.acc();
}
}
if (FLAGS_enable_rnn_acc && sequence.has_acceleration()) {
state(3, 0) = sequence.acceleration();
}
......
......@@ -22,9 +22,8 @@
#ifndef MODULES_PREDICTION_PREDICTOR_LANE_SEQUENCE_LANE_SEQUENCE_PREDICTOR_H_
#define MODULES_PREDICTION_PREDICTOR_LANE_SEQUENCE_LANE_SEQUENCE_PREDICTOR_H_
// #include <string>
#include <string>
#include <vector>
// #include "Eigen/Dense"
#include "modules/common/math/kalman_filter.h"
#include "modules/common/proto/pnc_point.pb.h"
......@@ -62,6 +61,7 @@ class LaneSequencePredictor : public SequencePredictor {
* @param A vector of generated trajectory points
*/
void DrawLaneSequenceTrajectoryPoints(
const Feature& feature, const std::string& lane_id,
const common::math::KalmanFilter<double, 4, 2, 0>& kf,
const LaneSequence& sequence, double total_time, double freq,
std::vector<common::TrajectoryPoint>* points);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册