提交 6e70b94d 编写于 作者: K kechxu 提交者: Kecheng Xu

Prediction: refactor offline mode, gflags, feature proto and data for learning

上级 99b708ab
......@@ -68,6 +68,7 @@ cc_library(
"//modules/common/util",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/proto:offline_features_proto",
"//modules/prediction/proto:prediction_proto",
],
)
......
......@@ -26,15 +26,17 @@ namespace prediction {
Features FeatureOutput::features_;
ListDataForLearning FeatureOutput::list_data_for_learning_;
PredictionObstacles prediction_obstacles_;
std::size_t FeatureOutput::idx_feature_ = 0;
std::size_t FeatureOutput::idx_learning_ = 0;
std::size_t idx_prediction_obstacle_ = 0;
void FeatureOutput::Close() {
ADEBUG << "Close feature output";
if (FLAGS_prediction_offline_mode) {
Write();
if (FLAGS_prediction_offline_mode == 1) {
WriteFeatureProto();
}
if (FLAGS_prediction_offline_dataforlearning) {
if (FLAGS_prediction_offline_mode == 2) {
WriteDataForLearning();
}
Clear();
......@@ -52,7 +54,7 @@ bool FeatureOutput::Ready() {
return true;
}
void FeatureOutput::Insert(const Feature& feature) {
void FeatureOutput::InsertFeatureProto(const Feature& feature) {
features_.add_feature()->CopyFrom(feature);
}
......@@ -70,7 +72,7 @@ void FeatureOutput::InsertDataForLearning(
ADEBUG << "Insert [" << category << "] data for learning";
}
void FeatureOutput::Write() {
void FeatureOutput::WriteFeatureProto() {
if (features_.feature_size() <= 0) {
ADEBUG << "Skip writing empty feature.";
} else {
......
......@@ -20,6 +20,7 @@
#include <string>
#include "modules/prediction/proto/offline_features.pb.h"
#include "modules/prediction/proto/prediction_obstacle.pb.h"
namespace apollo {
namespace prediction {
......@@ -51,7 +52,7 @@ class FeatureOutput {
* @brief Insert a feature
* @param A feature in proto
*/
static void Insert(const Feature& feature);
static void InsertFeatureProto(const Feature& feature);
/**
* @brief Insert a data_for_learning
......@@ -64,7 +65,7 @@ class FeatureOutput {
/**
* @brief Write features to a file
*/
static void Write();
static void WriteFeatureProto();
/**
* @brief Write DataForLearning features to a file
......@@ -88,6 +89,8 @@ class FeatureOutput {
static std::size_t idx_feature_;
static ListDataForLearning list_data_for_learning_;
static std::size_t idx_learning_;
static PredictionObstacles prediction_obstacles_;
static std::size_t idx_prediction_obstacle_;
};
} // namespace prediction
......
......@@ -35,7 +35,7 @@ TEST_F(FeatureOutputTest, insertion) {
Feature feature;
for (int i = 0; i < 3; ++i) {
Feature feature;
FeatureOutput::Insert(feature);
FeatureOutput::InsertFeatureProto(feature);
}
EXPECT_EQ(3, FeatureOutput::Size());
}
......@@ -44,7 +44,7 @@ TEST_F(FeatureOutputTest, clear) {
Feature feature;
for (int i = 0; i < 3; ++i) {
Feature feature;
FeatureOutput::Insert(feature);
FeatureOutput::InsertFeatureProto(feature);
}
FeatureOutput::Clear();
EXPECT_EQ(0, FeatureOutput::Size());
......
......@@ -144,7 +144,7 @@ void MessageProcess::OnPerception(
auto end_time6 = std::chrono::system_clock::now();
// Insert features to FeatureOutput for offline_mode
if (FLAGS_prediction_offline_mode) {
if (FLAGS_prediction_offline_mode == 1) {
for (const int id :
ptr_obstacles_container->curr_frame_predictable_obstacle_ids()) {
Obstacle* obstacle_ptr = ptr_obstacles_container->GetObstacle(id);
......@@ -155,7 +155,7 @@ void MessageProcess::OnPerception(
AERROR << "Obstacle [" << id << "] has no latest feature.";
return;
}
FeatureOutput::Insert(obstacle_ptr->latest_feature());
FeatureOutput::InsertFeatureProto(obstacle_ptr->latest_feature());
ADEBUG << "Insert feature into feature output";
}
// Not doing evaluation on offline mode
......
......@@ -47,9 +47,11 @@ DEFINE_string(
"a list of bag files or directories for offline mode. The items need to be "
"separated by colon ':'. If this value is not set, the prediction module "
"will use the listen to published ros topic mode.");
DEFINE_bool(prediction_offline_mode, false, "Prediction offline mode");
DEFINE_bool(prediction_offline_dataforlearning, false, "Whether to extract "
"the features for offline learning-models training.");
DEFINE_int32(prediction_offline_mode, 0,
"0: online mode, no dump file"
"1: dump feature proto to feature.x.bin"
"2: dump data for learning to datalearn.x.bin"
"3: dump predicted trajectory to predict_obstacles.x.bin");
// Bag replay timestamp gap
DEFINE_double(replay_timestamp_gap, 10.0,
......
......@@ -31,8 +31,7 @@ DECLARE_bool(prediction_test_mode);
DECLARE_double(prediction_test_duration);
DECLARE_string(prediction_offline_bags);
DECLARE_bool(prediction_offline_mode);
DECLARE_bool(prediction_offline_dataforlearning);
DECLARE_int32(prediction_offline_mode);
// Bag replay timestamp gap
DECLARE_double(replay_timestamp_gap);
......
......@@ -4,8 +4,7 @@
--noadjust_velocity_by_obstacle_heading
--noadjust_velocity_by_position_shift
--noenable_kf_tracking
--noprediction_offline_mode
--noprediction_offline_dataforlearning
--prediction_offline_mode=0
--lane_change_dist=10.0
......
......@@ -67,13 +67,13 @@ void ObstaclesContainer::Insert(const ::google::protobuf::Message& message) {
<< timestamp_ << "].";
return;
}
if (FLAGS_prediction_offline_mode) {
if (FLAGS_prediction_offline_mode == 1) {
if (std::fabs(timestamp - timestamp_) > FLAGS_replay_timestamp_gap ||
FeatureOutput::Size() > FLAGS_max_num_dump_feature) {
FeatureOutput::Write();
FeatureOutput::WriteFeatureProto();
}
}
if (FLAGS_prediction_offline_dataforlearning) {
if (FLAGS_prediction_offline_mode == 2) {
if (std::fabs(timestamp - timestamp_) > FLAGS_replay_timestamp_gap ||
FeatureOutput::SizeOfDataForLearning() > FLAGS_max_num_dump_feature) {
FeatureOutput::WriteDataForLearning();
......
......@@ -92,25 +92,31 @@ void CruiseMLPEvaluator::Evaluate(Obstacle* obstacle_ptr) {
continue;
}
if (!FLAGS_prediction_offline_mode) {
Eigen::MatrixXf obs_feature_mat = VectorToMatrixXf(feature_values, 0,
OBSTACLE_FEATURE_SIZE);
Eigen::MatrixXf lane_feature_mat = VectorToMatrixXf(feature_values,
OBSTACLE_FEATURE_SIZE + INTERACTION_FEATURE_SIZE,
static_cast<int>(feature_values.size()), SINGLE_LANE_FEATURE_SIZE,
LANE_POINTS_SIZE);
Eigen::MatrixXf model_output;
if (lane_sequence_ptr->vehicle_on_lane()) {
go_model_ptr_->Run({lane_feature_mat, obs_feature_mat}, &model_output);
} else {
cutin_model_ptr_->Run(
{lane_feature_mat, obs_feature_mat}, &model_output);
}
double probability = model_output(0, 0);
double finish_time = model_output(0, 1);
lane_sequence_ptr->set_probability(probability);
lane_sequence_ptr->set_time_to_lane_center(finish_time);
// Insert features to DataForLearning
if (FLAGS_prediction_offline_mode == 2) {
FeatureOutput::InsertDataForLearning(
*latest_feature_ptr, feature_values, "junction");
ADEBUG << "Save extracted features for learning locally.";
return; // Skip Compute probability for offline mode
}
Eigen::MatrixXf obs_feature_mat = VectorToMatrixXf(feature_values, 0,
OBSTACLE_FEATURE_SIZE);
Eigen::MatrixXf lane_feature_mat = VectorToMatrixXf(feature_values,
OBSTACLE_FEATURE_SIZE + INTERACTION_FEATURE_SIZE,
static_cast<int>(feature_values.size()), SINGLE_LANE_FEATURE_SIZE,
LANE_POINTS_SIZE);
Eigen::MatrixXf model_output;
if (lane_sequence_ptr->vehicle_on_lane()) {
go_model_ptr_->Run({lane_feature_mat, obs_feature_mat}, &model_output);
} else {
cutin_model_ptr_->Run(
{lane_feature_mat, obs_feature_mat}, &model_output);
}
double probability = model_output(0, 0);
double finish_time = model_output(0, 1);
lane_sequence_ptr->set_probability(probability);
lane_sequence_ptr->set_time_to_lane_center(finish_time);
}
}
......@@ -166,14 +172,6 @@ void CruiseMLPEvaluator::ExtractFeatureValues
feature_values->insert(feature_values->end(),
lane_feature_values.begin(),
lane_feature_values.end());
// For offline training, write the extracted features into proto.
if (FLAGS_prediction_offline_mode) {
SaveOfflineFeatures(lane_sequence_ptr, *feature_values);
ADEBUG << "Save cruise mlp features for obstacle ["
<< obstacle_ptr->id() << "] with dim ["
<< feature_values->size() << "]";
}
}
void CruiseMLPEvaluator::SetObstacleFeatureValues(
......
......@@ -75,7 +75,7 @@ void JunctionMLPEvaluator::Evaluate(Obstacle* obstacle_ptr) {
ExtractFeatureValues(obstacle_ptr, &feature_values);
// Insert features to DataForLearning
if (FLAGS_prediction_offline_dataforlearning) {
if (FLAGS_prediction_offline_mode == 2) {
FeatureOutput::InsertDataForLearning(
*latest_feature_ptr, feature_values, "junction");
ADEBUG << "Save extracted features for learning locally.";
......
......@@ -72,7 +72,7 @@ void LaneScanningEvaluator::Evaluate(
std::vector<double> feature_values;
ExtractFeatures(obstacle_ptr, lane_graph_ptr, &feature_values);
std::vector<double> labels = {0.0};
if (FLAGS_prediction_offline_dataforlearning) {
if (FLAGS_prediction_offline_mode == 2) {
FeatureOutput::InsertDataForLearning(*latest_feature_ptr, feature_values,
"cruise");
ADEBUG << "Save extracted features for learning locally.";
......
......@@ -80,6 +80,14 @@ void MLPEvaluator::Evaluate(Obstacle* obstacle_ptr) {
CHECK(lane_sequence_ptr != nullptr);
std::vector<double> feature_values;
ExtractFeatureValues(obstacle_ptr, lane_sequence_ptr, &feature_values);
// Insert features to DataForLearning
if (FLAGS_prediction_offline_mode == 2 &&
!obstacle_ptr->IsNearJunction()) {
FeatureOutput::InsertDataForLearning(
*latest_feature_ptr, feature_values, "mlp");
ADEBUG << "Save extracted features for learning locally.";
return; // Skip Compute probability for offline mode
}
double probability = ComputeProbability(feature_values);
double centripetal_acc_probability =
......@@ -124,10 +132,6 @@ void MLPEvaluator::ExtractFeatureValues(Obstacle* obstacle_ptr,
obstacle_feature_values.end());
feature_values->insert(feature_values->end(), lane_feature_values.begin(),
lane_feature_values.end());
if (FLAGS_prediction_offline_mode && !obstacle_ptr->IsNearJunction()) {
SaveOfflineFeatures(lane_sequence_ptr, *feature_values);
}
}
void MLPEvaluator::SaveOfflineFeatures(
......
......@@ -83,31 +83,6 @@ bool PredictionComponent::Init() {
prediction_writer_ =
node_->CreateWriter<PredictionObstacles>(FLAGS_prediction_topic);
if (FLAGS_prediction_offline_mode) {
if (!FeatureOutput::Ready()) {
AERROR << "Feature output is not ready.";
return false;
}
if (FLAGS_prediction_offline_bags.empty()) {
return true; // use listen to ROS topic mode
}
std::vector<std::string> inputs;
common::util::Split(FLAGS_prediction_offline_bags, ':', &inputs);
for (const auto& input : inputs) {
std::vector<std::string> offline_bags;
GetRecordFileNames(boost::filesystem::path(input), &offline_bags);
std::sort(offline_bags.begin(), offline_bags.end());
AINFO << "For input " << input << ", found " << offline_bags.size()
<< " rosbags to process";
for (std::size_t i = 0; i < offline_bags.size(); ++i) {
AINFO << "\tProcessing: [ " << i << " / " << offline_bags.size()
<< " ]: " << offline_bags[i];
MessageProcess::ProcessOfflineData(offline_bags[i]);
}
}
FeatureOutput::Close();
return false;
}
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册