From 16cff894ad330c89b591fcc8398aad3dcbc5922b Mon Sep 17 00:00:00 2001 From: Aaron Xiao Date: Wed, 2 Aug 2017 16:12:59 -0700 Subject: [PATCH] Simple code clean of prediction module. --- .../prediction/evaluator/evaluator_factory.cc | 2 +- modules/prediction/evaluator/evaluator_factory.h | 4 ++-- .../prediction/evaluator/evaluator_manager.cc | 9 +++------ modules/prediction/evaluator/evaluator_manager.h | 3 +-- .../evaluator/evaluator_manager_test.cc | 12 +++++------- .../evaluator/vehicle/mlp_evaluator.cc | 16 +++++++++------- .../prediction/evaluator/vehicle/mlp_evaluator.h | 10 +++++----- .../evaluator/vehicle/mlp_evaluator_test.cc | 8 ++++---- modules/prediction/prediction.cc | 7 ++----- .../predictor/pedestrian/regional_predictor.cc | 9 +++++---- modules/prediction/predictor/predictor.cc | 13 ++++--------- .../prediction/predictor/predictor_factory.cc | 5 ++--- .../prediction/predictor/predictor_manager.cc | 14 +++++--------- modules/prediction/predictor/predictor_manager.h | 4 +--- .../predictor/predictor_manager_test.cc | 12 +++++------- .../predictor/vehicle/free_move_predictor.cc | 14 +++++++------- .../vehicle/free_move_predictor_test.cc | 2 +- 17 files changed, 62 insertions(+), 82 deletions(-) diff --git a/modules/prediction/evaluator/evaluator_factory.cc b/modules/prediction/evaluator/evaluator_factory.cc index 6fa0988fbc..d1db0d38e8 100644 --- a/modules/prediction/evaluator/evaluator_factory.cc +++ b/modules/prediction/evaluator/evaluator_factory.cc @@ -16,8 +16,8 @@ #include "modules/prediction/evaluator/evaluator_factory.h" -#include "modules/prediction/evaluator/vehicle/mlp_evaluator.h" #include "modules/common/log.h" +#include "modules/prediction/evaluator/vehicle/mlp_evaluator.h" namespace apollo { namespace prediction { diff --git a/modules/prediction/evaluator/evaluator_factory.h b/modules/prediction/evaluator/evaluator_factory.h index 3de3a3e1e1..ff39a785ce 100644 --- a/modules/prediction/evaluator/evaluator_factory.h +++ b/modules/prediction/evaluator/evaluator_factory.h @@ -24,10 +24,10 @@ #include +#include "modules/common/macro.h" +#include "modules/common/util/factory.h" #include "modules/prediction/evaluator/evaluator.h" #include "modules/prediction/proto/prediction_conf.pb.h" -#include "modules/common/util/factory.h" -#include "modules/common/macro.h" /** * @namespace apollo::prediction diff --git a/modules/prediction/evaluator/evaluator_manager.cc b/modules/prediction/evaluator/evaluator_manager.cc index 9efab2a512..c748b0800c 100644 --- a/modules/prediction/evaluator/evaluator_manager.cc +++ b/modules/prediction/evaluator/evaluator_manager.cc @@ -16,10 +16,10 @@ #include "modules/prediction/evaluator/evaluator_manager.h" +#include "modules/common/log.h" #include "modules/prediction/evaluator/vehicle/mlp_evaluator.h" #include "modules/prediction/container/container_manager.h" #include "modules/prediction/container/obstacles/obstacles_container.h" -#include "modules/common/log.h" namespace apollo { namespace prediction { @@ -64,11 +64,8 @@ void EvaluatorManager::Init(const PredictionConf& config) { Evaluator* EvaluatorManager::GetEvaluator( const ObstacleConf::EvaluatorType& type) { - if (evaluators_.find(type) != evaluators_.end()) { - return evaluators_[type].get(); - } else { - return nullptr; - } + auto it = evaluators_.find(type); + return it != evaluators_.end() ? it->second.get() : nullptr; } void EvaluatorManager::Run( diff --git a/modules/prediction/evaluator/evaluator_manager.h b/modules/prediction/evaluator/evaluator_manager.h index 4c6700a0a7..879e06302e 100644 --- a/modules/prediction/evaluator/evaluator_manager.h +++ b/modules/prediction/evaluator/evaluator_manager.h @@ -84,8 +84,7 @@ class EvaluatorManager { void RegisterEvaluators(); private: - std::map> evaluators_; + std::map> evaluators_; ObstacleConf::EvaluatorType vehicle_on_lane_evaluator_ = ObstacleConf::MLP_EVALUATOR; diff --git a/modules/prediction/evaluator/evaluator_manager_test.cc b/modules/prediction/evaluator/evaluator_manager_test.cc index 8dc7bbb82d..a05beb13ba 100644 --- a/modules/prediction/evaluator/evaluator_manager_test.cc +++ b/modules/prediction/evaluator/evaluator_manager_test.cc @@ -14,13 +14,13 @@ * limitations under the License. *****************************************************************************/ -#include +#include "modules/prediction/evaluator/evaluator_manager.h" +#include #include "gtest/gtest.h" -#include "modules/prediction/evaluator/evaluator_manager.h" -#include "modules/prediction/proto/prediction_conf.pb.h" #include "modules/common/util/file.h" +#include "modules/prediction/proto/prediction_conf.pb.h" namespace apollo { namespace prediction { @@ -38,10 +38,8 @@ class EvaluatorManagerTest : public ::testing::Test { TEST_F(EvaluatorManagerTest, GetEvaluators) { std::string conf_file = "modules/prediction/testdata/prediction_conf.pb.txt"; - bool ret_load_conf = ::apollo::common::util::GetProtoFromFile(conf_file, - &conf_); - EXPECT_TRUE(ret_load_conf); - EXPECT_TRUE(conf_.IsInitialized()); + CHECK(apollo::common::util::GetProtoFromFile(conf_file, &conf_)) + << "Failed to load " << conf_file; manager_->Init(conf_); diff --git a/modules/prediction/evaluator/vehicle/mlp_evaluator.cc b/modules/prediction/evaluator/vehicle/mlp_evaluator.cc index 0116075042..9f0e131fcc 100644 --- a/modules/prediction/evaluator/vehicle/mlp_evaluator.cc +++ b/modules/prediction/evaluator/vehicle/mlp_evaluator.cc @@ -14,15 +14,16 @@ * limitations under the License. *****************************************************************************/ +#include "modules/prediction/evaluator/vehicle/mlp_evaluator.h" + #include #include -#include "modules/prediction/evaluator/vehicle/mlp_evaluator.h" -#include "modules/prediction/common/prediction_gflags.h" #include "modules/common/math/math_utils.h" -#include "modules/prediction/common/prediction_util.h" -#include "modules/map/proto/map_lane.pb.h" #include "modules/common/util/file.h" +#include "modules/map/proto/map_lane.pb.h" +#include "modules/prediction/common/prediction_gflags.h" +#include "modules/prediction/common/prediction_util.h" namespace apollo { namespace prediction { @@ -80,11 +81,12 @@ void MLPEvaluator::ExtractFeatureValues(Obstacle* obstacle_ptr, feature_values_.clear(); int id = obstacle_ptr->id(); std::vector obstacle_feature_values; - if (obstacle_feature_values_map_.find(id) == - obstacle_feature_values_map_.end()) { + + auto it = obstacle_feature_values_map_.find(id); + if (it == obstacle_feature_values_map_.end()) { SetObstacleFeatureValues(obstacle_ptr, &obstacle_feature_values); } else { - obstacle_feature_values = obstacle_feature_values_map_[id]; + obstacle_feature_values = it->second; } if (obstacle_feature_values.size() != OBSTACLE_FEATURE_SIZE) { diff --git a/modules/prediction/evaluator/vehicle/mlp_evaluator.h b/modules/prediction/evaluator/vehicle/mlp_evaluator.h index 7085be79cd..f56b7b2854 100644 --- a/modules/prediction/evaluator/vehicle/mlp_evaluator.h +++ b/modules/prediction/evaluator/vehicle/mlp_evaluator.h @@ -17,15 +17,15 @@ #ifndef MODULES_PREDICTION_EVALUATOR_VEHICLE_MLP_EVALUATOR_H_ #define MODULES_PREDICTION_EVALUATOR_VEHICLE_MLP_EVALUATOR_H_ -#include -#include -#include #include +#include +#include +#include -#include "modules/prediction/evaluator/evaluator.h" #include "modules/prediction/container/obstacles/obstacle.h" -#include "modules/prediction/proto/lane_graph.pb.h" +#include "modules/prediction/evaluator/evaluator.h" #include "modules/prediction/proto/fnn_vehicle_model.pb.h" +#include "modules/prediction/proto/lane_graph.pb.h" namespace apollo { namespace prediction { diff --git a/modules/prediction/evaluator/vehicle/mlp_evaluator_test.cc b/modules/prediction/evaluator/vehicle/mlp_evaluator_test.cc index 0fded2faf7..099a95ae9b 100644 --- a/modules/prediction/evaluator/vehicle/mlp_evaluator_test.cc +++ b/modules/prediction/evaluator/vehicle/mlp_evaluator_test.cc @@ -21,21 +21,21 @@ #include "gtest/gtest.h" -#include "modules/perception/proto/perception_obstacle.pb.h" -#include "modules/prediction/common/prediction_gflags.h" #include "modules/common/util/file.h" +#include "modules/prediction/common/prediction_gflags.h" #include "modules/prediction/container/obstacles/obstacle.h" #include "modules/prediction/container/obstacles/obstacles_container.h" +#include "modules/perception/proto/perception_obstacle.pb.h" namespace apollo { namespace prediction { class MLPEvaluatorTest : public ::testing::Test { public: - virtual void SetUp() { + void SetUp() override { std::string file = "modules/prediction/testdata/single_perception_vehicle_onlane.pb.txt"; - apollo::common::util::GetProtoFromFile(file, &perception_obstacles_); + CHECK(apollo::common::util::GetProtoFromFile(file, &perception_obstacles_)); FLAGS_map_file = "modules/prediction/testdata/kml_map.bin"; } protected: diff --git a/modules/prediction/prediction.cc b/modules/prediction/prediction.cc index 27df419743..2477009947 100644 --- a/modules/prediction/prediction.cc +++ b/modules/prediction/prediction.cc @@ -69,11 +69,8 @@ Status Prediction::Init() { EvaluatorManager::instance()->Init(prediction_conf_); PredictorManager::instance()->Init(prediction_conf_); - CHECK(AdapterManager::GetLocalization()) - << "Localization is not ready."; - - CHECK(AdapterManager::GetPerceptionObstacles()) - << "Perception is not ready."; + CHECK(AdapterManager::GetLocalization()) << "Localization is not ready."; + CHECK(AdapterManager::GetPerceptionObstacles()) << "Perception is not ready."; // Set perception obstacle callback function AdapterManager::SetPerceptionObstaclesCallback(&Prediction::OnPerception, diff --git a/modules/prediction/predictor/pedestrian/regional_predictor.cc b/modules/prediction/predictor/pedestrian/regional_predictor.cc index 283592a597..11b0e149fc 100644 --- a/modules/prediction/predictor/pedestrian/regional_predictor.cc +++ b/modules/prediction/predictor/pedestrian/regional_predictor.cc @@ -14,11 +14,12 @@ * limitations under the License. *****************************************************************************/ -#include +#include "modules/prediction/predictor/pedestrian/regional_predictor.h" + #include +#include #include -#include "modules/prediction/predictor/pedestrian/regional_predictor.h" #include "modules/prediction/common/prediction_util.h" #include "modules/prediction/common/prediction_gflags.h" #include "modules/common/math/kalman_filter.h" @@ -154,8 +155,8 @@ void RegionalPredictor::GenerateMovingTrajectory( std::vector right_points; DrawMovingTrajectory(position, velocity, acc, - obstacle->kf_pedestrian_tracker(), total_time, - &left_points, &right_points); + obstacle->kf_pedestrian_tracker(), total_time, + &left_points, &right_points); int start_index = GetTrajectorySize(); Trajectory left_trajectory; diff --git a/modules/prediction/predictor/predictor.cc b/modules/prediction/predictor/predictor.cc index 45bc55b660..d6943ba2c7 100644 --- a/modules/prediction/predictor/predictor.cc +++ b/modules/prediction/predictor/predictor.cc @@ -32,19 +32,14 @@ int Predictor::GetTrajectorySize() { void Predictor::GenerateTrajectory( const std::vector<::apollo::common::TrajectoryPoint>& points, Trajectory* trajectory) { - if (points.size() <= 0) { - return; - } - - for (const auto& point : points) { - trajectory->add_trajectory_point()->CopyFrom(point); - } + trajectory->mutable_trajectory_point()->MergeFrom( + {points.begin(), points.end()}); } void Predictor::SetEqualProbability(double probability, int start_index) { int num = GetTrajectorySize(); - if (start_index >= 0 && num > 0 && num > start_index) { - probability = probability / static_cast(num - start_index); + if (start_index >= 0 && num > start_index) { + probability /= static_cast(num - start_index); for (int i = start_index; i < num; ++i) { prediction_obstacle_.mutable_trajectory(i)->set_probability(probability); } diff --git a/modules/prediction/predictor/predictor_factory.cc b/modules/prediction/predictor/predictor_factory.cc index 5c628feecc..a0bb46fadf 100644 --- a/modules/prediction/predictor/predictor_factory.cc +++ b/modules/prediction/predictor/predictor_factory.cc @@ -16,10 +16,9 @@ #include "modules/prediction/predictor/predictor_factory.h" -#include "modules/prediction/predictor/vehicle/lane_sequence_predictor.h" -#include "modules/prediction/predictor/vehicle/free_move_predictor.h" - #include "modules/common/log.h" +#include "modules/prediction/predictor/vehicle/free_move_predictor.h" +#include "modules/prediction/predictor/vehicle/lane_sequence_predictor.h" namespace apollo { namespace prediction { diff --git a/modules/prediction/predictor/predictor_manager.cc b/modules/prediction/predictor/predictor_manager.cc index 2a62cf6c29..ec9d32fa63 100644 --- a/modules/prediction/predictor/predictor_manager.cc +++ b/modules/prediction/predictor/predictor_manager.cc @@ -14,10 +14,10 @@ * limitations under the License. *****************************************************************************/ -#include - #include "modules/prediction/predictor/predictor_manager.h" +#include + #include "modules/prediction/predictor/vehicle/lane_sequence_predictor.h" #include "modules/prediction/predictor/vehicle/free_move_predictor.h" #include "modules/prediction/predictor/pedestrian/regional_predictor.h" @@ -80,15 +80,11 @@ void PredictorManager::Init(const PredictionConf& config) { Predictor* PredictorManager::GetPredictor( const ObstacleConf::PredictorType& type) { - if (predictors_.find(type) != predictors_.end()) { - return predictors_[type].get(); - } else { - return nullptr; - } + auto it = predictors_.find(type); + return it != predictors_.end() ? it->second.get() : nullptr; } -void PredictorManager::Run( - const PerceptionObstacles& perception_obstacles) { +void PredictorManager::Run(const PerceptionObstacles& perception_obstacles) { prediction_obstacles_.Clear(); ObstaclesContainer *container = dynamic_cast( ContainerManager::instance()->GetContainer( diff --git a/modules/prediction/predictor/predictor_manager.h b/modules/prediction/predictor/predictor_manager.h index 19bb8195f2..c247078630 100644 --- a/modules/prediction/predictor/predictor_manager.h +++ b/modules/prediction/predictor/predictor_manager.h @@ -22,7 +22,6 @@ #ifndef MODULES_PREDICTION_PREDICTOR_PREDICTOR_MANAGER_H_ #define MODULES_PREDICTION_PREDICTOR_PREDICTOR_MANAGER_H_ -#include #include #include @@ -92,8 +91,7 @@ class PredictorManager { void RegisterPredictors(); private: - std::map> predictors_; + std::map> predictors_; ObstacleConf::PredictorType vehicle_on_lane_predictor_ = ObstacleConf::LANE_SEQUENCE_PREDICTOR; diff --git a/modules/prediction/predictor/predictor_manager_test.cc b/modules/prediction/predictor/predictor_manager_test.cc index e85bf9eb69..abcefb962d 100644 --- a/modules/prediction/predictor/predictor_manager_test.cc +++ b/modules/prediction/predictor/predictor_manager_test.cc @@ -14,11 +14,11 @@ * limitations under the License. *****************************************************************************/ +#include "modules/prediction/predictor/predictor_manager.h" + #include #include "gtest/gtest.h" - -#include "modules/prediction/predictor/predictor_manager.h" #include "modules/prediction/proto/prediction_conf.pb.h" #include "modules/common/util/file.h" @@ -27,7 +27,7 @@ namespace prediction { class PredictorManagerTest : public ::testing::Test { public: - virtual void SetUp() { + void SetUp() override { manager_ = PredictorManager::instance(); } @@ -38,10 +38,8 @@ class PredictorManagerTest : public ::testing::Test { TEST_F(PredictorManagerTest, GetPredictor) { std::string conf_file = "modules/prediction/testdata/prediction_conf.pb.txt"; - bool ret_load_conf = ::apollo::common::util::GetProtoFromFile(conf_file, - &conf_); - EXPECT_TRUE(ret_load_conf); - EXPECT_TRUE(conf_.IsInitialized()); + CHECK(apollo::common::util::GetProtoFromFile(conf_file, &conf_)) + << "Failed to load " << conf_file; manager_->Init(conf_); diff --git a/modules/prediction/predictor/vehicle/free_move_predictor.cc b/modules/prediction/predictor/vehicle/free_move_predictor.cc index 2cae6050cd..b116d00676 100644 --- a/modules/prediction/predictor/vehicle/free_move_predictor.cc +++ b/modules/prediction/predictor/vehicle/free_move_predictor.cc @@ -16,10 +16,10 @@ #include "modules/prediction/predictor/vehicle/free_move_predictor.h" -#include #include #include #include +#include #include "Eigen/Dense" #include "modules/prediction/common/prediction_gflags.h" @@ -114,13 +114,13 @@ void FreeMovePredictor::DrawFreeMoveTrajectoryPoints( for (size_t i = 0; i < static_cast(total_time / freq); ++i) { double speed = std::hypot(v_x, v_y); if (speed <= std::numeric_limits::epsilon()) { - speed = 0.0; - v_x = 0.0; - v_y = 0.0; - acc_x = 0.0; - acc_y = 0.0; + speed = 0.0; + v_x = 0.0; + v_y = 0.0; + acc_x = 0.0; + acc_y = 0.0; } else if (speed > FLAGS_max_speed) { - speed = FLAGS_max_speed; + speed = FLAGS_max_speed; } // update theta diff --git a/modules/prediction/predictor/vehicle/free_move_predictor_test.cc b/modules/prediction/predictor/vehicle/free_move_predictor_test.cc index 787749bf20..5c1c518c58 100644 --- a/modules/prediction/predictor/vehicle/free_move_predictor_test.cc +++ b/modules/prediction/predictor/vehicle/free_move_predictor_test.cc @@ -36,7 +36,7 @@ class FreeMovePredictorTest : public ::testing::Test { virtual void SetUp() { std::string file = "modules/prediction/testdata/single_perception_vehicle_offlane.pb.txt"; - apollo::common::util::GetProtoFromFile(file, &perception_obstacles_); + CHECK(apollo::common::util::GetProtoFromFile(file, &perception_obstacles_)); FLAGS_map_file = "modules/prediction/testdata/kml_map.bin"; FLAGS_p_var = 0.1; FLAGS_q_var = 0.01; -- GitLab