未验证 提交 b40b6e93 编写于 作者: C changsh726 提交者: GitHub

Bazel: make modules/prediction build passed. (#11394)

上级 ccb6ebf2
......@@ -11,6 +11,11 @@ cc_library(
linkstatic = False,
linkopts = [
"-L/usr/local/libtorch_cpu/lib",
"-lc10",
"-ltorch",
"-ltorch_cpu",
],
deps = [
"@python3",
],
)
......@@ -11,7 +11,12 @@ cc_library(
linkstatic = False,
linkopts = [
"-L/usr/local/libtorch_gpu/lib",
"-lc10",
"-ltorch",
"-ltorch_cpu",
"-ltorch_cuda",
],
deps = [
"@python3",
],
)
......@@ -16,7 +16,7 @@ cc_library(
"//modules/prediction/common:message_process",
"//modules/prediction/evaluator:evaluator_manager",
"//modules/prediction/predictor:predictor_manager",
"//modules/prediction/proto:offline_features_proto",
"//modules/prediction/proto:offline_features_cc_proto",
"//modules/prediction/scenario:scenario_manager",
"//modules/prediction/submodules:evaluator_submodule_lib",
"//modules/prediction/submodules:predictor_submodule_lib",
......@@ -35,6 +35,9 @@ cc_test(
":prediction_data",
":prediction_testdata",
],
linkopts = [
"-lgomp",
],
deps = [
":prediction_component_lib",
],
......
......@@ -83,8 +83,8 @@ cc_library(
"//modules/common/util",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/container/obstacles:obstacle",
"//modules/prediction/proto:offline_features_proto",
"//modules/prediction/proto:prediction_proto",
"//modules/prediction/proto:offline_features_cc_proto",
"//modules/prediction/proto:prediction_obstacle_cc_proto",
],
)
......@@ -110,7 +110,7 @@ cc_library(
"//modules/prediction/common:prediction_constants",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/common:prediction_system_gflags",
"//modules/prediction/proto:lane_graph_proto",
"//modules/prediction/proto:lane_graph_cc_proto",
],
)
......@@ -141,7 +141,7 @@ cc_library(
deps = [
"//modules/common/math",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/proto:lane_graph_proto",
"//modules/prediction/proto:lane_graph_cc_proto",
],
)
......@@ -165,7 +165,7 @@ cc_library(
hdrs = ["environment_features.h"],
deps = [
"//cyber",
"//modules/common/proto:geometry_proto",
"//modules/common/proto:geometry_cc_proto",
],
)
......@@ -188,7 +188,7 @@ cc_library(
],
deps = [
":prediction_map",
"//modules/prediction/proto:feature_proto",
"//modules/prediction/proto:feature_cc_proto",
],
)
......@@ -221,7 +221,7 @@ cc_library(
"//modules/common/adapters:adapter_gflags",
"//modules/prediction/evaluator:evaluator_manager",
"//modules/prediction/predictor:predictor_manager",
"//modules/prediction/proto:offline_features_proto",
"//modules/prediction/proto:offline_features_cc_proto",
"//modules/prediction/scenario:scenario_manager",
"//modules/prediction/util:data_extraction",
],
......@@ -264,7 +264,7 @@ cc_library(
"//modules/common/util",
"//modules/prediction/container:container_manager",
"//modules/prediction/container/pose:pose_container",
"//modules/prediction/proto:feature_proto",
"//modules/prediction/proto:feature_cc_proto",
"@opencv",
],
)
......
......@@ -21,6 +21,7 @@
#include "cyber/common/file.h"
#include "cyber/record/record_reader.h"
#include "cyber/record/record_writer.h"
#include "modules/common/adapters/adapter_gflags.h"
#include "modules/prediction/common/feature_output.h"
#include "modules/prediction/common/junction_analyzer.h"
......@@ -297,36 +298,23 @@ void MessageProcess::ProcessOfflineData(
message.channel_name, perception_obstacles, message.time);
}
PredictionObstacles prediction_obstacles;
OnPerception(perception_obstacles, &prediction_obstacles);
if (FLAGS_prediction_offline_mode == PredictionConstants::kDumpRecord) {
SingleMessage single_message;
std::string content = "";
prediction_obstacles.SerializeToString(&content);
single_message.set_content(content);
single_message.set_time(message.time);
single_message.set_channel_name(FLAGS_prediction_topic);
writer.WriteMessage(RecordMessageToSingleMessage(message));
=======
OnPerception(perception_obstacles, container_manager,
evaluator_manager, predictor_manager, scenario_manager,
OnPerception(perception_obstacles, container_manager, evaluator_manager,
predictor_manager, scenario_manager,
&prediction_obstacles);
if (FLAGS_prediction_offline_mode ==
PredictionConstants::kDumpRecord) {
if (FLAGS_prediction_offline_mode == PredictionConstants::kDumpRecord) {
writer.WriteMessage<PredictionObstacles>(
prediction_conf.topic_conf().perception_obstacle_topic(),
prediction_obstacles, message.time);
AINFO << "Generated a new prediction message.";
>>>>>>> master
}
}
} else if (message.channel_name ==
prediction_conf.topic_conf().localization_topic()) {
LocalizationEstimate localization;
if (localization.ParseFromString(message.content)) {
if (FLAGS_prediction_offline_mode ==
PredictionConstants::kDumpRecord) {
writer.WriteMessage<LocalizationEstimate>(
message.channel_name, localization, message.time);
if (FLAGS_prediction_offline_mode == PredictionConstants::kDumpRecord) {
writer.WriteMessage<LocalizationEstimate>(message.channel_name,
localization, message.time);
}
OnLocalization(container_manager.get(), localization);
}
......@@ -341,7 +329,7 @@ void MessageProcess::ProcessOfflineData(
if (FLAGS_prediction_offline_mode == PredictionConstants::kDumpRecord) {
writer.Close();
}
}
}
} // namespace prediction
} // namespace prediction
} // namespace apollo
......@@ -11,7 +11,7 @@ cc_library(
"-DMODULE_NAME=\\\"prediction\\\"",
],
deps = [
"//modules/common/adapters/proto:adapter_config_proto",
"//modules/common/adapters/proto:adapter_config_cc_proto",
"//modules/prediction/container/adc_trajectory:adc_trajectory_container",
"//modules/prediction/container/obstacles:obstacles_container",
"//modules/prediction/container/pose:pose_container",
......
......@@ -11,9 +11,10 @@ cc_library(
"-DMODULE_NAME=\\\"prediction\\\"",
],
deps = [
"//modules/planning/proto:planning_proto",
"//modules/planning/proto:planning_cc_proto",
"//modules/prediction/common:prediction_map",
"//modules/prediction/container",
"//modules/prediction/proto:lane_graph_cc_proto",
],
)
......
......@@ -16,7 +16,7 @@ cc_library(
"//modules/prediction/common:prediction_constants",
"//modules/prediction/container",
"//modules/prediction/container/obstacles:obstacle",
"//modules/prediction/proto:prediction_proto",
"//modules/prediction/proto:prediction_obstacle_cc_proto",
"//modules/prediction/submodules:submodule_output",
],
)
......@@ -33,8 +33,8 @@ cc_library(
"//modules/prediction/common:junction_analyzer",
"//modules/prediction/container/obstacles:obstacle_clusters",
"//modules/prediction/network/rnn_model",
"//modules/prediction/proto:prediction_conf_proto",
"//modules/prediction/proto:prediction_proto",
"//modules/prediction/proto:prediction_conf_cc_proto",
"//modules/prediction/proto:prediction_obstacle_cc_proto",
],
)
......@@ -77,7 +77,7 @@ cc_library(
],
deps = [
"//modules/prediction/common:road_graph",
"//modules/prediction/proto:feature_proto",
"//modules/prediction/proto:feature_cc_proto",
],
)
......
......@@ -12,8 +12,8 @@ cc_library(
],
deps = [
"//modules/common/math:quaternion",
"//modules/localization/proto:localization_proto",
"//modules/perception/proto:perception_proto",
"//modules/localization/proto:localization_cc_proto",
"//modules/perception/proto:perception_obstacle_cc_proto",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/container",
],
......
......@@ -13,7 +13,7 @@ cc_library(
deps = [
"//modules/prediction/common:prediction_map",
"//modules/prediction/container",
"//modules/storytelling/proto:story_proto",
"//modules/storytelling/proto:story_cc_proto",
],
)
......
......@@ -26,7 +26,8 @@ cc_library(
"//modules/prediction/evaluator/vehicle:lane_scanning_evaluator",
"//modules/prediction/evaluator/vehicle:mlp_evaluator",
"//modules/prediction/evaluator/vehicle:semantic_lstm_evaluator",
"//modules/prediction/proto:prediction_conf_proto",
"//modules/prediction/proto:prediction_conf_cc_proto",
"//third_party:libtorch",
],
)
......@@ -38,6 +39,9 @@ cc_test(
"//modules/prediction:prediction_data",
"//modules/prediction:prediction_testdata",
],
linkopts = [
"-lgomp",
],
deps = [
"//modules/prediction/common:kml_map_based_test",
"//modules/prediction/evaluator:evaluator_manager",
......
......@@ -16,7 +16,7 @@ cc_library(
"//modules/prediction/common:validation_checker",
"//modules/prediction/container/obstacles:obstacles_container",
"//modules/prediction/evaluator",
"//modules/prediction/proto:fnn_vehicle_model_proto",
"//modules/prediction/proto:fnn_vehicle_model_cc_proto",
],
)
......@@ -70,6 +70,7 @@ cc_library(
hdrs = ["junction_mlp_evaluator.h"],
copts = [
"-DMODULE_NAME=\\\"prediction\\\"",
"-fopenmp",
],
deps = [
"//modules/common/math:geometry",
......@@ -90,6 +91,9 @@ cc_test(
"//modules/prediction:prediction_data",
"//modules/prediction:prediction_testdata",
],
linkopts = [
"-lgomp",
],
deps = [
"//modules/prediction/common:kml_map_based_test",
"//modules/prediction/evaluator/vehicle:junction_mlp_evaluator",
......@@ -103,6 +107,7 @@ cc_library(
hdrs = ["junction_map_evaluator.h"],
copts = [
"-DMODULE_NAME=\\\"prediction\\\"",
"-fopenmp",
],
deps = [
"//modules/prediction/common:prediction_util",
......@@ -119,6 +124,7 @@ cc_library(
hdrs = ["cruise_mlp_evaluator.h"],
copts = [
"-DMODULE_NAME=\\\"prediction\\\"",
"-fopenmp",
],
deps = [
"//modules/prediction/common:prediction_util",
......@@ -137,6 +143,9 @@ cc_test(
"//modules/prediction:prediction_data",
"//modules/prediction:prediction_testdata",
],
linkopts = [
"-lgomp",
],
deps = [
"//modules/prediction/common:kml_map_based_test",
"//modules/prediction/evaluator/vehicle:cruise_mlp_evaluator",
......@@ -149,6 +158,7 @@ cc_library(
hdrs = ["lane_scanning_evaluator.h"],
copts = [
"-DMODULE_NAME=\\\"prediction\\\"",
"-fopenmp",
],
deps = [
"//modules/prediction/container:container_manager",
......@@ -179,6 +189,7 @@ cc_library(
hdrs = ["semantic_lstm_evaluator.h"],
copts = [
"-DMODULE_NAME=\\\"prediction\\\"",
"-fopenmp",
],
deps = [
"//modules/prediction/common:prediction_util",
......
......@@ -94,11 +94,12 @@ bool JunctionMapEvaluator::Evaluate(Obstacle* obstacle_ptr,
junction_exit_mask[0][i] = static_cast<float>(feature_values[i]);
}
torch_inputs.push_back(c10::ivalue::Tuple::create(
at::Tensor torch_input_tensor;
torch_inputs.push_back(c10::ivalue::Tuple::createNamed(
{std::move(img_tensor.to(device_)),
std::move(junction_exit_mask.to(device_))},
c10::TupleType::create(
std::vector<c10::TypePtr>(2, c10::TensorType::create()))));
c10::TupleType::create(std::vector<c10::TypePtr>(
2, c10::TensorType::create(torch_input_tensor)))));
// Compute probability
std::vector<double> probability;
......
......@@ -98,11 +98,13 @@ bool SemanticLSTMEvaluator::Evaluate(Obstacle* obstacle_ptr,
// Build input features for torch
std::vector<torch::jit::IValue> torch_inputs;
torch_inputs.push_back(c10::ivalue::Tuple::create(
at::Tensor torch_input_tensor;
torch_inputs.push_back(c10::ivalue::Tuple::createNamed(
{std::move(img_tensor.to(device_)), std::move(obstacle_pos.to(device_)),
std::move(obstacle_pos_step.to(device_))},
c10::TupleType::create(
std::vector<c10::TypePtr>(3, c10::TensorType::create()))));
c10::TupleType::create(std::vector<c10::TypePtr>(
3, c10::TensorType::create(torch_input_tensor)))));
// Compute pred_traj
std::vector<double> pred_traj;
......@@ -110,8 +112,9 @@ bool SemanticLSTMEvaluator::Evaluate(Obstacle* obstacle_ptr,
auto start_time = std::chrono::system_clock::now();
at::Tensor torch_output_tensor = torch_default_output_tensor_;
if (obstacle_ptr->IsPedestrian()) {
torch_output_tensor = torch_pedestrian_model_.forward(torch_inputs).
toTensor().to(torch::kCPU);
torch_output_tensor = torch_pedestrian_model_.forward(torch_inputs)
.toTensor()
.to(torch::kCPU);
} else {
torch_output_tensor =
torch_vehicle_model_.forward(torch_inputs).toTensor().to(torch::kCPU);
......@@ -166,8 +169,8 @@ bool SemanticLSTMEvaluator::Evaluate(Obstacle* obstacle_ptr,
rotation_matrix(1, 1) = std::cos(heading);
Eigen::Matrix2d cov_matrix;
cov_matrix = rotation_matrix * cov_matrix_r *
(rotation_matrix.transpose());
cov_matrix =
rotation_matrix * cov_matrix_r * (rotation_matrix.transpose());
double sigma_x = std::sqrt(std::abs(cov_matrix(0, 0)));
double sigma_y = std::sqrt(std::abs(cov_matrix(1, 1)));
double corr = cov_matrix(0, 1) / (sigma_x + FLAGS_double_precision) /
......@@ -258,11 +261,13 @@ void SemanticLSTMEvaluator::LoadModel() {
torch::Tensor obstacle_pos = torch::zeros({1, 20, 2});
torch::Tensor obstacle_pos_step = torch::zeros({1, 20, 2});
std::vector<torch::jit::IValue> torch_inputs;
torch_inputs.push_back(c10::ivalue::Tuple::create(
at::Tensor torch_input_tensor;
torch_inputs.push_back(c10::ivalue::Tuple::createNamed(
{std::move(img_tensor.to(device_)), std::move(obstacle_pos.to(device_)),
std::move(obstacle_pos_step.to(device_))},
c10::TupleType::create(
std::vector<c10::TypePtr>(3, c10::TensorType::create()))));
c10::TupleType::create(std::vector<c10::TypePtr>(
3, c10::TensorType::create(torch_input_tensor)))));
// Run one inference to avoid very slow first inference later
torch_default_output_tensor_ =
torch_vehicle_model_.forward(torch_inputs).toTensor().to(torch::kCPU);
......
......@@ -8,6 +8,8 @@ cc_library(
srcs = ["rnn_model.cc"],
hdrs = ["rnn_model.h"],
deps = [
"//cyber/base:macros",
"//cyber/common:macros",
"//modules/prediction/network:net_model",
],
)
......@@ -20,6 +22,7 @@ cc_test(
"//modules/prediction:prediction_data",
],
deps = [
"//cyber/common:file",
"//modules/prediction/network/rnn_model",
"@com_google_googletest//:gtest_main",
],
......
......@@ -9,6 +9,9 @@ cc_binary(
copts = [
"-DMODULE_NAME=\\\"prediction\\\"",
],
linkopts = [
"-lgomp",
],
deps = [
"//modules/prediction/common:message_process",
"//third_party:boost",
......
......@@ -21,7 +21,7 @@ cc_library(
"//modules/prediction/predictor/lane_sequence:lane_sequence_predictor",
"//modules/prediction/predictor/move_sequence:move_sequence_predictor",
"//modules/prediction/predictor/single_lane:single_lane_predictor",
"//modules/prediction/proto:prediction_conf_proto",
"//modules/prediction/proto:prediction_conf_cc_proto",
"//modules/prediction/scenario:scenario_manager",
],
)
......@@ -34,6 +34,9 @@ cc_test(
"//modules/prediction:prediction_data",
"//modules/prediction:prediction_testdata",
],
linkopts = [
"-lgomp",
],
deps = [
"//modules/prediction/common:kml_map_based_test",
"//modules/prediction/evaluator:evaluator_manager",
......
......@@ -18,7 +18,7 @@ cc_library(
"//modules/prediction/container/obstacles:obstacle_clusters",
"//modules/prediction/container/obstacles:obstacles_container",
"//modules/prediction/predictor/sequence:sequence_predictor",
"//modules/prediction/proto:lane_graph_proto",
"//modules/prediction/proto:lane_graph_cc_proto",
],
)
......
......@@ -11,7 +11,7 @@ cc_library(
"-DMODULE_NAME=\\\"prediction\\\"",
],
deps = [
"//modules/common/adapters/proto:adapter_config_proto",
"//modules/common/adapters/proto:adapter_config_cc_proto",
"//modules/prediction/common:feature_output",
"//modules/prediction/common:prediction_util",
"//modules/prediction/predictor/sequence:sequence_predictor",
......
......@@ -24,6 +24,9 @@ cc_test(
"//modules/prediction:prediction_data",
"//modules/prediction:prediction_testdata",
],
linkopts = [
"-lgomp",
],
deps = [
"//modules/prediction/common:kml_map_based_test",
"//modules/prediction/container/obstacles:obstacles_container",
......
......@@ -11,7 +11,7 @@ cc_library(
"-DMODULE_NAME=\\\"prediction\\\"",
],
deps = [
"//modules/prediction/proto:scenario_proto",
"//modules/prediction/proto:scenario_cc_proto",
],
)
......
......@@ -12,7 +12,7 @@ cc_library(
],
deps = [
"//modules/common/util:lru_cache",
"//modules/perception/proto:perception_proto",
"//modules/perception/proto:perception_obstacle_cc_proto",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/container/obstacles:obstacle",
"@com_google_absl//absl/time",
......@@ -29,9 +29,9 @@ cc_library(
deps = [
"//cyber",
"//modules/common/adapters:adapter_gflags",
"//modules/common/adapters/proto:adapter_config_proto",
"//modules/common/adapters/proto:adapter_config_cc_proto",
"//modules/common/time",
"//modules/perception/proto:perception_proto",
"//modules/perception/proto:perception_obstacle_cc_proto",
"//modules/prediction/common:message_process",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/evaluator:evaluator_manager",
......@@ -56,15 +56,15 @@ cc_library(
deps = [
"//cyber",
"//modules/common/adapters:adapter_gflags",
"//modules/common/adapters/proto:adapter_config_proto",
"//modules/common/adapters/proto:adapter_config_cc_proto",
"//modules/common/time",
"//modules/common/util:message_util",
"//modules/perception/proto:perception_proto",
"//modules/perception/proto:perception_obstacle_cc_proto",
"//modules/prediction/common:message_process",
"//modules/prediction/common:prediction_gflags",
"//modules/prediction/container/adc_trajectory:adc_trajectory_container",
"//modules/prediction/predictor:predictor_manager",
"//modules/prediction/proto:prediction_proto",
"//modules/prediction/proto:prediction_obstacle_cc_proto",
],
alwayslink = True,
)
......
......@@ -16,10 +16,10 @@
#pragma once
#include <boost/filesystem.hpp>
#include <boost/range/iterator_range.hpp>
#include <string>
#include <vector>
#include <boost/filesystem.hpp>
#include <boost/range/iterator_range.hpp>
namespace apollo {
namespace prediction {
......
......@@ -54,7 +54,8 @@ bazel_build_with_dist_cache \
//modules/v2x/... \
//modules/dreamview/... \
//modules/guardian/... \
//modules/localization/...
//modules/localization/... \
//modules/prediction/...
bazel_test_with_dist_cache \
//cyber/... \
......@@ -77,6 +78,7 @@ bazel_test_with_dist_cache //modules/drivers/...
bash scripts/install_esdcan_library.sh uninstall
bazel_build_with_dist_cache //modules/tools/...
bazel build //modules/tools/...
# Note(storypku): bazel test works except some lint errors in cyber_visualizer.
# Check cyber_visualizer's functionality once stablized.
bazel_test_with_dist_cache $(bazel query //modules/tools/... except //modules/tools/visualizer/...)
......@@ -94,7 +96,7 @@ echo "########################### All check passed! ###########################"
# TODO(?): bazel test //modules/map/...
# TODO(?): bazel build //modules/contrib/...
# TODO(?): bazel build //modules/perception/...
# TODO(changsh726): bazel build //modules/prediction/...
# TODO(?): bazel test //modules/prediction/...
# TODO(?): bazel build //modules/third_party_perception/...
# TODO(?): apollo.sh build
# TODO(?): apollo.sh test
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册