提交 16b6d7f3 编写于 作者: P panjiacheng 提交者: HongyiSun

Prediction: encode obstacle features in lane_aggregating_evaluator.

上级 137b92ae
......@@ -164,6 +164,10 @@ DEFINE_string(torch_pedestrian_interaction_prediction_layer_file,
"/apollo/modules/prediction/data/"
"pedestrian_interaction_prediction_layer.pt",
"pedestrian interaction prediction layer");
DEFINE_string(torch_lane_aggregating_obstacle_encoding_file,
"/apollo/modules/prediction/data/"
"lane_aggregating_obstacle_encoding_layer.pt",
"");
DEFINE_int32(max_num_obstacles, 300,
"maximal number of obstacles stored in obstacles container.");
DEFINE_double(valid_position_diff_threshold, 0.5,
......
......@@ -95,6 +95,7 @@ DECLARE_string(torch_pedestrian_interaction_position_embedding_file);
DECLARE_string(torch_pedestrian_interaction_social_embedding_file);
DECLARE_string(torch_pedestrian_interaction_single_lstm_file);
DECLARE_string(torch_pedestrian_interaction_prediction_layer_file);
DECLARE_string(torch_lane_aggregating_obstacle_encoding_file);
DECLARE_string(evaluator_vehicle_rnn_file);
DECLARE_string(evaluator_vehicle_cruise_mlp_file);
DECLARE_int32(max_num_obstacles);
......
......@@ -40,7 +40,13 @@ using apollo::perception::PerceptionObstacles;
LaneAggregatingEvaluator::LaneAggregatingEvaluator()
: device_(torch::kCPU) {
// LoadModel();
LoadModel();
}
void LaneAggregatingEvaluator::LoadModel() {
torch::set_num_threads(1);
torch_obstacle_encoding_ptr_ = torch::jit::load(
FLAGS_torch_lane_aggregating_obstacle_encoding_file, device_);
}
bool LaneAggregatingEvaluator::Evaluate(Obstacle* obstacle_ptr) {
......@@ -69,6 +75,35 @@ bool LaneAggregatingEvaluator::Evaluate(Obstacle* obstacle_ptr) {
<< " lane sequences to scan.";
// Extract features, and do model inferencing.
// 1. Encode the obstacle features.
std::vector<double> obstacle_feature_values;
if (!ExtractObstacleFeatures(obstacle_ptr, &obstacle_feature_values)) {
ADEBUG << "Failed to extract obstacle features for obs_id = " << id;
}
if (obstacle_feature_values.size() != OBSTACLE_FEATURE_SIZE) {
ADEBUG << "Obstacle [" << id << "] has fewer than "
<< "expected obstacle feature_values "
<< obstacle_feature_values.size() << ".";
return false;
}
ADEBUG << "Obstacle feature size = " << obstacle_feature_values.size();
std::vector<torch::jit::IValue> obstacle_encoding_inputs;
torch::Tensor obstacle_encoding_inputs_tensor =
torch::zeros({1, static_cast<int>(obstacle_feature_values.size())});
for (size_t i = 0; i < obstacle_feature_values.size(); ++i) {
obstacle_encoding_inputs_tensor[0][i] = static_cast<float>(
obstacle_feature_values[i]);
}
obstacle_encoding_inputs.push_back(
std::move(obstacle_encoding_inputs_tensor));
torch::Tensor obstalce_encoding = torch_obstacle_encoding_ptr_
->forward(obstacle_encoding_inputs).toTensor().to(torch::kCPU);
// 2. Encode the lane features.
// 3. Aggregate the lane features.
// 4. Make prediction.
return true;
}
......
......@@ -80,15 +80,7 @@ class LaneAggregatingEvaluator : public Evaluator {
std::vector<std::vector<double>>* feature_values,
std::vector<int>* lane_sequence_idx_to_remove);
// struct LSTMState {
// double timestamp;
// torch::Tensor ct;
// torch::Tensor ht;
// };
// void Clear();
// void LoadModel();
void LoadModel();
private:
// std::unordered_map<int, LSTMState> obstacle_id_lstm_state_map_;
......@@ -98,8 +90,8 @@ class LaneAggregatingEvaluator : public Evaluator {
// torch_social_embedding_ptr_ = nullptr;
// std::shared_ptr<torch::jit::script::Module>
// torch_single_lstm_ptr_ = nullptr;
// std::shared_ptr<torch::jit::script::Module>
// torch_prediction_layer_ptr_ = nullptr;
std::shared_ptr<torch::jit::script::Module>
torch_obstacle_encoding_ptr_ = nullptr;
torch::Device device_;
static const size_t OBSTACLE_FEATURE_SIZE = 20 * 9;\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册