提交 16cff894 编写于 作者: A Aaron Xiao 提交者: Jiangtao Hu

Simple code clean of prediction module.

上级 516fb93d
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include "modules/prediction/evaluator/evaluator_factory.h" #include "modules/prediction/evaluator/evaluator_factory.h"
#include "modules/prediction/evaluator/vehicle/mlp_evaluator.h"
#include "modules/common/log.h" #include "modules/common/log.h"
#include "modules/prediction/evaluator/vehicle/mlp_evaluator.h"
namespace apollo { namespace apollo {
namespace prediction { namespace prediction {
......
...@@ -24,10 +24,10 @@ ...@@ -24,10 +24,10 @@
#include <memory> #include <memory>
#include "modules/common/macro.h"
#include "modules/common/util/factory.h"
#include "modules/prediction/evaluator/evaluator.h" #include "modules/prediction/evaluator/evaluator.h"
#include "modules/prediction/proto/prediction_conf.pb.h" #include "modules/prediction/proto/prediction_conf.pb.h"
#include "modules/common/util/factory.h"
#include "modules/common/macro.h"
/** /**
* @namespace apollo::prediction * @namespace apollo::prediction
......
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
#include "modules/prediction/evaluator/evaluator_manager.h" #include "modules/prediction/evaluator/evaluator_manager.h"
#include "modules/common/log.h"
#include "modules/prediction/evaluator/vehicle/mlp_evaluator.h" #include "modules/prediction/evaluator/vehicle/mlp_evaluator.h"
#include "modules/prediction/container/container_manager.h" #include "modules/prediction/container/container_manager.h"
#include "modules/prediction/container/obstacles/obstacles_container.h" #include "modules/prediction/container/obstacles/obstacles_container.h"
#include "modules/common/log.h"
namespace apollo { namespace apollo {
namespace prediction { namespace prediction {
...@@ -64,11 +64,8 @@ void EvaluatorManager::Init(const PredictionConf& config) { ...@@ -64,11 +64,8 @@ void EvaluatorManager::Init(const PredictionConf& config) {
Evaluator* EvaluatorManager::GetEvaluator( Evaluator* EvaluatorManager::GetEvaluator(
const ObstacleConf::EvaluatorType& type) { const ObstacleConf::EvaluatorType& type) {
if (evaluators_.find(type) != evaluators_.end()) { auto it = evaluators_.find(type);
return evaluators_[type].get(); return it != evaluators_.end() ? it->second.get() : nullptr;
} else {
return nullptr;
}
} }
void EvaluatorManager::Run( void EvaluatorManager::Run(
......
...@@ -84,8 +84,7 @@ class EvaluatorManager { ...@@ -84,8 +84,7 @@ class EvaluatorManager {
void RegisterEvaluators(); void RegisterEvaluators();
private: private:
std::map<ObstacleConf::EvaluatorType, std::map<ObstacleConf::EvaluatorType, std::unique_ptr<Evaluator>> evaluators_;
std::unique_ptr<Evaluator>> evaluators_;
ObstacleConf::EvaluatorType vehicle_on_lane_evaluator_ = ObstacleConf::EvaluatorType vehicle_on_lane_evaluator_ =
ObstacleConf::MLP_EVALUATOR; ObstacleConf::MLP_EVALUATOR;
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
* limitations under the License. * limitations under the License.
*****************************************************************************/ *****************************************************************************/
#include <string> #include "modules/prediction/evaluator/evaluator_manager.h"
#include <string>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "modules/prediction/evaluator/evaluator_manager.h"
#include "modules/prediction/proto/prediction_conf.pb.h"
#include "modules/common/util/file.h" #include "modules/common/util/file.h"
#include "modules/prediction/proto/prediction_conf.pb.h"
namespace apollo { namespace apollo {
namespace prediction { namespace prediction {
...@@ -38,10 +38,8 @@ class EvaluatorManagerTest : public ::testing::Test { ...@@ -38,10 +38,8 @@ class EvaluatorManagerTest : public ::testing::Test {
TEST_F(EvaluatorManagerTest, GetEvaluators) { TEST_F(EvaluatorManagerTest, GetEvaluators) {
std::string conf_file = "modules/prediction/testdata/prediction_conf.pb.txt"; std::string conf_file = "modules/prediction/testdata/prediction_conf.pb.txt";
bool ret_load_conf = ::apollo::common::util::GetProtoFromFile(conf_file, CHECK(apollo::common::util::GetProtoFromFile(conf_file, &conf_))
&conf_); << "Failed to load " << conf_file;
EXPECT_TRUE(ret_load_conf);
EXPECT_TRUE(conf_.IsInitialized());
manager_->Init(conf_); manager_->Init(conf_);
......
...@@ -14,15 +14,16 @@ ...@@ -14,15 +14,16 @@
* limitations under the License. * limitations under the License.
*****************************************************************************/ *****************************************************************************/
#include "modules/prediction/evaluator/vehicle/mlp_evaluator.h"
#include <cmath> #include <cmath>
#include <numeric> #include <numeric>
#include "modules/prediction/evaluator/vehicle/mlp_evaluator.h"
#include "modules/prediction/common/prediction_gflags.h"
#include "modules/common/math/math_utils.h" #include "modules/common/math/math_utils.h"
#include "modules/prediction/common/prediction_util.h"
#include "modules/map/proto/map_lane.pb.h"
#include "modules/common/util/file.h" #include "modules/common/util/file.h"
#include "modules/map/proto/map_lane.pb.h"
#include "modules/prediction/common/prediction_gflags.h"
#include "modules/prediction/common/prediction_util.h"
namespace apollo { namespace apollo {
namespace prediction { namespace prediction {
...@@ -80,11 +81,12 @@ void MLPEvaluator::ExtractFeatureValues(Obstacle* obstacle_ptr, ...@@ -80,11 +81,12 @@ void MLPEvaluator::ExtractFeatureValues(Obstacle* obstacle_ptr,
feature_values_.clear(); feature_values_.clear();
int id = obstacle_ptr->id(); int id = obstacle_ptr->id();
std::vector<double> obstacle_feature_values; std::vector<double> obstacle_feature_values;
if (obstacle_feature_values_map_.find(id) ==
obstacle_feature_values_map_.end()) { auto it = obstacle_feature_values_map_.find(id);
if (it == obstacle_feature_values_map_.end()) {
SetObstacleFeatureValues(obstacle_ptr, &obstacle_feature_values); SetObstacleFeatureValues(obstacle_ptr, &obstacle_feature_values);
} else { } else {
obstacle_feature_values = obstacle_feature_values_map_[id]; obstacle_feature_values = it->second;
} }
if (obstacle_feature_values.size() != OBSTACLE_FEATURE_SIZE) { if (obstacle_feature_values.size() != OBSTACLE_FEATURE_SIZE) {
......
...@@ -17,15 +17,15 @@ ...@@ -17,15 +17,15 @@
#ifndef MODULES_PREDICTION_EVALUATOR_VEHICLE_MLP_EVALUATOR_H_ #ifndef MODULES_PREDICTION_EVALUATOR_VEHICLE_MLP_EVALUATOR_H_
#define MODULES_PREDICTION_EVALUATOR_VEHICLE_MLP_EVALUATOR_H_ #define MODULES_PREDICTION_EVALUATOR_VEHICLE_MLP_EVALUATOR_H_
#include <vector>
#include <unordered_map>
#include <string>
#include <memory> #include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "modules/prediction/evaluator/evaluator.h"
#include "modules/prediction/container/obstacles/obstacle.h" #include "modules/prediction/container/obstacles/obstacle.h"
#include "modules/prediction/proto/lane_graph.pb.h" #include "modules/prediction/evaluator/evaluator.h"
#include "modules/prediction/proto/fnn_vehicle_model.pb.h" #include "modules/prediction/proto/fnn_vehicle_model.pb.h"
#include "modules/prediction/proto/lane_graph.pb.h"
namespace apollo { namespace apollo {
namespace prediction { namespace prediction {
......
...@@ -21,21 +21,21 @@ ...@@ -21,21 +21,21 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "modules/perception/proto/perception_obstacle.pb.h"
#include "modules/prediction/common/prediction_gflags.h"
#include "modules/common/util/file.h" #include "modules/common/util/file.h"
#include "modules/prediction/common/prediction_gflags.h"
#include "modules/prediction/container/obstacles/obstacle.h" #include "modules/prediction/container/obstacles/obstacle.h"
#include "modules/prediction/container/obstacles/obstacles_container.h" #include "modules/prediction/container/obstacles/obstacles_container.h"
#include "modules/perception/proto/perception_obstacle.pb.h"
namespace apollo { namespace apollo {
namespace prediction { namespace prediction {
class MLPEvaluatorTest : public ::testing::Test { class MLPEvaluatorTest : public ::testing::Test {
public: public:
virtual void SetUp() { void SetUp() override {
std::string file = std::string file =
"modules/prediction/testdata/single_perception_vehicle_onlane.pb.txt"; "modules/prediction/testdata/single_perception_vehicle_onlane.pb.txt";
apollo::common::util::GetProtoFromFile(file, &perception_obstacles_); CHECK(apollo::common::util::GetProtoFromFile(file, &perception_obstacles_));
FLAGS_map_file = "modules/prediction/testdata/kml_map.bin"; FLAGS_map_file = "modules/prediction/testdata/kml_map.bin";
} }
protected: protected:
......
...@@ -69,11 +69,8 @@ Status Prediction::Init() { ...@@ -69,11 +69,8 @@ Status Prediction::Init() {
EvaluatorManager::instance()->Init(prediction_conf_); EvaluatorManager::instance()->Init(prediction_conf_);
PredictorManager::instance()->Init(prediction_conf_); PredictorManager::instance()->Init(prediction_conf_);
CHECK(AdapterManager::GetLocalization()) CHECK(AdapterManager::GetLocalization()) << "Localization is not ready.";
<< "Localization is not ready."; CHECK(AdapterManager::GetPerceptionObstacles()) << "Perception is not ready.";
CHECK(AdapterManager::GetPerceptionObstacles())
<< "Perception is not ready.";
// Set perception obstacle callback function // Set perception obstacle callback function
AdapterManager::SetPerceptionObstaclesCallback(&Prediction::OnPerception, AdapterManager::SetPerceptionObstaclesCallback(&Prediction::OnPerception,
......
...@@ -14,11 +14,12 @@ ...@@ -14,11 +14,12 @@
* limitations under the License. * limitations under the License.
*****************************************************************************/ *****************************************************************************/
#include <limits> #include "modules/prediction/predictor/pedestrian/regional_predictor.h"
#include <cmath> #include <cmath>
#include <limits>
#include <utility> #include <utility>
#include "modules/prediction/predictor/pedestrian/regional_predictor.h"
#include "modules/prediction/common/prediction_util.h" #include "modules/prediction/common/prediction_util.h"
#include "modules/prediction/common/prediction_gflags.h" #include "modules/prediction/common/prediction_gflags.h"
#include "modules/common/math/kalman_filter.h" #include "modules/common/math/kalman_filter.h"
...@@ -154,8 +155,8 @@ void RegionalPredictor::GenerateMovingTrajectory( ...@@ -154,8 +155,8 @@ void RegionalPredictor::GenerateMovingTrajectory(
std::vector<TrajectoryPoint> right_points; std::vector<TrajectoryPoint> right_points;
DrawMovingTrajectory(position, velocity, acc, DrawMovingTrajectory(position, velocity, acc,
obstacle->kf_pedestrian_tracker(), total_time, obstacle->kf_pedestrian_tracker(), total_time,
&left_points, &right_points); &left_points, &right_points);
int start_index = GetTrajectorySize(); int start_index = GetTrajectorySize();
Trajectory left_trajectory; Trajectory left_trajectory;
......
...@@ -32,19 +32,14 @@ int Predictor::GetTrajectorySize() { ...@@ -32,19 +32,14 @@ int Predictor::GetTrajectorySize() {
void Predictor::GenerateTrajectory( void Predictor::GenerateTrajectory(
const std::vector<::apollo::common::TrajectoryPoint>& points, const std::vector<::apollo::common::TrajectoryPoint>& points,
Trajectory* trajectory) { Trajectory* trajectory) {
if (points.size() <= 0) { trajectory->mutable_trajectory_point()->MergeFrom(
return; {points.begin(), points.end()});
}
for (const auto& point : points) {
trajectory->add_trajectory_point()->CopyFrom(point);
}
} }
void Predictor::SetEqualProbability(double probability, int start_index) { void Predictor::SetEqualProbability(double probability, int start_index) {
int num = GetTrajectorySize(); int num = GetTrajectorySize();
if (start_index >= 0 && num > 0 && num > start_index) { if (start_index >= 0 && num > start_index) {
probability = probability / static_cast<double>(num - start_index); probability /= static_cast<double>(num - start_index);
for (int i = start_index; i < num; ++i) { for (int i = start_index; i < num; ++i) {
prediction_obstacle_.mutable_trajectory(i)->set_probability(probability); prediction_obstacle_.mutable_trajectory(i)->set_probability(probability);
} }
......
...@@ -16,10 +16,9 @@ ...@@ -16,10 +16,9 @@
#include "modules/prediction/predictor/predictor_factory.h" #include "modules/prediction/predictor/predictor_factory.h"
#include "modules/prediction/predictor/vehicle/lane_sequence_predictor.h"
#include "modules/prediction/predictor/vehicle/free_move_predictor.h"
#include "modules/common/log.h" #include "modules/common/log.h"
#include "modules/prediction/predictor/vehicle/free_move_predictor.h"
#include "modules/prediction/predictor/vehicle/lane_sequence_predictor.h"
namespace apollo { namespace apollo {
namespace prediction { namespace prediction {
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
* limitations under the License. * limitations under the License.
*****************************************************************************/ *****************************************************************************/
#include <memory>
#include "modules/prediction/predictor/predictor_manager.h" #include "modules/prediction/predictor/predictor_manager.h"
#include <memory>
#include "modules/prediction/predictor/vehicle/lane_sequence_predictor.h" #include "modules/prediction/predictor/vehicle/lane_sequence_predictor.h"
#include "modules/prediction/predictor/vehicle/free_move_predictor.h" #include "modules/prediction/predictor/vehicle/free_move_predictor.h"
#include "modules/prediction/predictor/pedestrian/regional_predictor.h" #include "modules/prediction/predictor/pedestrian/regional_predictor.h"
...@@ -80,15 +80,11 @@ void PredictorManager::Init(const PredictionConf& config) { ...@@ -80,15 +80,11 @@ void PredictorManager::Init(const PredictionConf& config) {
Predictor* PredictorManager::GetPredictor( Predictor* PredictorManager::GetPredictor(
const ObstacleConf::PredictorType& type) { const ObstacleConf::PredictorType& type) {
if (predictors_.find(type) != predictors_.end()) { auto it = predictors_.find(type);
return predictors_[type].get(); return it != predictors_.end() ? it->second.get() : nullptr;
} else {
return nullptr;
}
} }
void PredictorManager::Run( void PredictorManager::Run(const PerceptionObstacles& perception_obstacles) {
const PerceptionObstacles& perception_obstacles) {
prediction_obstacles_.Clear(); prediction_obstacles_.Clear();
ObstaclesContainer *container = dynamic_cast<ObstaclesContainer*>( ObstaclesContainer *container = dynamic_cast<ObstaclesContainer*>(
ContainerManager::instance()->GetContainer( ContainerManager::instance()->GetContainer(
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#ifndef MODULES_PREDICTION_PREDICTOR_PREDICTOR_MANAGER_H_ #ifndef MODULES_PREDICTION_PREDICTOR_PREDICTOR_MANAGER_H_
#define MODULES_PREDICTION_PREDICTOR_PREDICTOR_MANAGER_H_ #define MODULES_PREDICTION_PREDICTOR_PREDICTOR_MANAGER_H_
#include <unordered_map>
#include <map> #include <map>
#include <memory> #include <memory>
...@@ -92,8 +91,7 @@ class PredictorManager { ...@@ -92,8 +91,7 @@ class PredictorManager {
void RegisterPredictors(); void RegisterPredictors();
private: private:
std::map<ObstacleConf::PredictorType, std::map<ObstacleConf::PredictorType, std::unique_ptr<Predictor>> predictors_;
std::unique_ptr<Predictor>> predictors_;
ObstacleConf::PredictorType vehicle_on_lane_predictor_ = ObstacleConf::PredictorType vehicle_on_lane_predictor_ =
ObstacleConf::LANE_SEQUENCE_PREDICTOR; ObstacleConf::LANE_SEQUENCE_PREDICTOR;
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
* limitations under the License. * limitations under the License.
*****************************************************************************/ *****************************************************************************/
#include "modules/prediction/predictor/predictor_manager.h"
#include <string> #include <string>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "modules/prediction/predictor/predictor_manager.h"
#include "modules/prediction/proto/prediction_conf.pb.h" #include "modules/prediction/proto/prediction_conf.pb.h"
#include "modules/common/util/file.h" #include "modules/common/util/file.h"
...@@ -27,7 +27,7 @@ namespace prediction { ...@@ -27,7 +27,7 @@ namespace prediction {
class PredictorManagerTest : public ::testing::Test { class PredictorManagerTest : public ::testing::Test {
public: public:
virtual void SetUp() { void SetUp() override {
manager_ = PredictorManager::instance(); manager_ = PredictorManager::instance();
} }
...@@ -38,10 +38,8 @@ class PredictorManagerTest : public ::testing::Test { ...@@ -38,10 +38,8 @@ class PredictorManagerTest : public ::testing::Test {
TEST_F(PredictorManagerTest, GetPredictor) { TEST_F(PredictorManagerTest, GetPredictor) {
std::string conf_file = "modules/prediction/testdata/prediction_conf.pb.txt"; std::string conf_file = "modules/prediction/testdata/prediction_conf.pb.txt";
bool ret_load_conf = ::apollo::common::util::GetProtoFromFile(conf_file, CHECK(apollo::common::util::GetProtoFromFile(conf_file, &conf_))
&conf_); << "Failed to load " << conf_file;
EXPECT_TRUE(ret_load_conf);
EXPECT_TRUE(conf_.IsInitialized());
manager_->Init(conf_); manager_->Init(conf_);
......
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
#include "modules/prediction/predictor/vehicle/free_move_predictor.h" #include "modules/prediction/predictor/vehicle/free_move_predictor.h"
#include <vector>
#include <cmath> #include <cmath>
#include <limits> #include <limits>
#include <utility> #include <utility>
#include <vector>
#include "Eigen/Dense" #include "Eigen/Dense"
#include "modules/prediction/common/prediction_gflags.h" #include "modules/prediction/common/prediction_gflags.h"
...@@ -114,13 +114,13 @@ void FreeMovePredictor::DrawFreeMoveTrajectoryPoints( ...@@ -114,13 +114,13 @@ void FreeMovePredictor::DrawFreeMoveTrajectoryPoints(
for (size_t i = 0; i < static_cast<size_t>(total_time / freq); ++i) { for (size_t i = 0; i < static_cast<size_t>(total_time / freq); ++i) {
double speed = std::hypot(v_x, v_y); double speed = std::hypot(v_x, v_y);
if (speed <= std::numeric_limits<double>::epsilon()) { if (speed <= std::numeric_limits<double>::epsilon()) {
speed = 0.0; speed = 0.0;
v_x = 0.0; v_x = 0.0;
v_y = 0.0; v_y = 0.0;
acc_x = 0.0; acc_x = 0.0;
acc_y = 0.0; acc_y = 0.0;
} else if (speed > FLAGS_max_speed) { } else if (speed > FLAGS_max_speed) {
speed = FLAGS_max_speed; speed = FLAGS_max_speed;
} }
// update theta // update theta
......
...@@ -36,7 +36,7 @@ class FreeMovePredictorTest : public ::testing::Test { ...@@ -36,7 +36,7 @@ class FreeMovePredictorTest : public ::testing::Test {
virtual void SetUp() { virtual void SetUp() {
std::string file = std::string file =
"modules/prediction/testdata/single_perception_vehicle_offlane.pb.txt"; "modules/prediction/testdata/single_perception_vehicle_offlane.pb.txt";
apollo::common::util::GetProtoFromFile(file, &perception_obstacles_); CHECK(apollo::common::util::GetProtoFromFile(file, &perception_obstacles_));
FLAGS_map_file = "modules/prediction/testdata/kml_map.bin"; FLAGS_map_file = "modules/prediction/testdata/kml_map.bin";
FLAGS_p_var = 0.1; FLAGS_p_var = 0.1;
FLAGS_q_var = 0.01; FLAGS_q_var = 0.01;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册