提交 0823cf39 编写于 作者: P panjiacheng 提交者: HongyiSun

Prediction: encode all lane features and put them into a vector.

上级 16b6d7f3
......@@ -167,7 +167,11 @@ DEFINE_string(torch_pedestrian_interaction_prediction_layer_file,
DEFINE_string(torch_lane_aggregating_obstacle_encoding_file,
"/apollo/modules/prediction/data/"
"lane_aggregating_obstacle_encoding_layer.pt",
"");
"lane aggregating obstacle encoding layer");
DEFINE_string(torch_lane_aggregating_lane_encoding_file,
"/apollo/modules/prediction/data/"
"lane_aggregating_lane_encoding_layer.pt",
"lane aggregating lane encoding layer");
DEFINE_int32(max_num_obstacles, 300,
"maximal number of obstacles stored in obstacles container.");
DEFINE_double(valid_position_diff_threshold, 0.5,
......
......@@ -96,6 +96,7 @@ DECLARE_string(torch_pedestrian_interaction_social_embedding_file);
DECLARE_string(torch_pedestrian_interaction_single_lstm_file);
DECLARE_string(torch_pedestrian_interaction_prediction_layer_file);
DECLARE_string(torch_lane_aggregating_obstacle_encoding_file);
DECLARE_string(torch_lane_aggregating_lane_encoding_file);
DECLARE_string(evaluator_vehicle_rnn_file);
DECLARE_string(evaluator_vehicle_cruise_mlp_file);
DECLARE_int32(max_num_obstacles);
......
......@@ -47,6 +47,8 @@ void LaneAggregatingEvaluator::LoadModel() {
torch::set_num_threads(1);
torch_obstacle_encoding_ptr_ = torch::jit::load(
FLAGS_torch_lane_aggregating_obstacle_encoding_file, device_);
torch_lane_encoding_ptr_ = torch::jit::load(
FLAGS_torch_lane_aggregating_lane_encoding_file, device_);
}
bool LaneAggregatingEvaluator::Evaluate(Obstacle* obstacle_ptr) {
......@@ -100,6 +102,32 @@ bool LaneAggregatingEvaluator::Evaluate(Obstacle* obstacle_ptr) {
->forward(obstacle_encoding_inputs).toTensor().to(torch::kCPU);
// 2. Encode the lane features.
std::vector<std::vector<double>> lane_feature_values;
std::vector<int> lane_sequence_idx_to_remove;
if (!ExtractStaticEnvFeatures(obstacle_ptr, lane_graph_ptr,
&lane_feature_values, &lane_sequence_idx_to_remove)) {
AERROR << "Failed to extract static environmental features around obs_id = "
<< id;
}
std::vector<torch::Tensor> lane_encoding_list;
for (const auto& single_lane_feature_values : lane_feature_values) {
if (single_lane_feature_values.size() != SINGLE_LANE_FEATURE_SIZE) {
AERROR << "Obstacle [" << id << "] has incorrect lane feature size.";
return false;
}
std::vector<torch::jit::IValue> single_lane_encoding_inputs;
torch::Tensor single_lane_encoding_inputs_tensor =
torch::zeros({1, static_cast<int>(single_lane_feature_values.size())});
for (size_t i = 0; i < single_lane_feature_values.size(); ++i) {
single_lane_encoding_inputs_tensor[0][i] = static_cast<float>(
single_lane_feature_values[i]);
}
single_lane_encoding_inputs.push_back(
std::move(single_lane_encoding_inputs_tensor));
torch::Tensor single_lane_encoding = torch_lane_encoding_ptr_
->forward(single_lane_encoding_inputs).toTensor().to(torch::kCPU);
lane_encoding_list.push_back(std::move(single_lane_encoding));
}
// 3. Aggregate the lane features.
......
......@@ -92,6 +92,8 @@ class LaneAggregatingEvaluator : public Evaluator {
// torch_single_lstm_ptr_ = nullptr;
std::shared_ptr<torch::jit::script::Module>
torch_obstacle_encoding_ptr_ = nullptr;
std::shared_ptr<torch::jit::script::Module>
torch_lane_encoding_ptr_ = nullptr;
torch::Device device_;
static const size_t OBSTACLE_FEATURE_SIZE = 20 * 9;\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册