提交 a1425793 编写于 作者: K kechxu 提交者: Xiangquan Xiao

Prediction: load models

上级 51ca837a
......@@ -148,6 +148,22 @@ DEFINE_string(torch_vehicle_cruise_cutin_file,
DEFINE_string(torch_vehicle_lane_scanning_file,
"/apollo/modules/prediction/data/lane_scanning_vehicle_model.pt",
"Vehicle lane scanning model file");
DEFINE_string(torch_pedestrian_interaction_position_embedding_file,
"/apollo/modules/prediction/data/"
"pedestrian_interaction_position_embedding.pt",
"pedestrian interaction position embedding");
DEFINE_string(torch_pedestrian_interaction_social_embedding_file,
"/apollo/modules/prediction/data/"
"pedestrian_interaction_social_embedding.pt",
"pedestrian interaction social embedding");
DEFINE_string(torch_pedestrian_interaction_single_lstm_file,
"/apollo/modules/prediction/data/"
"pedestrian_interaction_single_lstm.pt",
"pedestrian interaction single lstm");
DEFINE_string(torch_pedestrian_interaction_prediction_layer_file,
"/apollo/modules/prediction/data/"
"pedestrian_interaction_prediction_layer.pt",
"pedestrian interaction prediction layer");
DEFINE_int32(max_num_obstacles, 300,
"maximal number of obstacles stored in obstacles container.");
DEFINE_double(valid_position_diff_threshold, 0.5,
......
......@@ -91,6 +91,10 @@ DECLARE_string(torch_vehicle_junction_map_file);
DECLARE_string(torch_vehicle_cruise_go_file);
DECLARE_string(torch_vehicle_cruise_cutin_file);
DECLARE_string(torch_vehicle_lane_scanning_file);
DECLARE_string(torch_pedestrian_interaction_position_embedding_file);
DECLARE_string(torch_pedestrian_interaction_social_embedding_file);
DECLARE_string(torch_pedestrian_interaction_single_lstm_file);
DECLARE_string(torch_pedestrian_interaction_prediction_layer_file);
DECLARE_string(evaluator_vehicle_rnn_file);
DECLARE_string(evaluator_vehicle_cruise_mlp_file);
DECLARE_int32(max_num_obstacles);
......
......@@ -32,6 +32,9 @@ using apollo::common::adapter::AdapterConfig;
using apollo::perception::PerceptionObstacle;
using apollo::perception::PerceptionObstacles;
PedestrianInteractionEvaluator::PedestrianInteractionEvaluator()
: device_(torch::kCPU) { }
void PedestrianInteractionEvaluator::Clear() {
auto ptr_obstacles_container =
ContainerManager::Instance()->GetContainer<ObstaclesContainer>(
......@@ -49,6 +52,18 @@ void PedestrianInteractionEvaluator::Clear() {
}
}
void PedestrianInteractionEvaluator::LoadModel() {
torch::set_num_threads(1);
torch_position_embedding_ptr_ = torch::jit::load(
FLAGS_torch_pedestrian_interaction_position_embedding_file, device_);
torch_social_embedding_ptr_ = torch::jit::load(
FLAGS_torch_pedestrian_interaction_social_embedding_file, device_);
torch_single_lstm_ptr_ = torch::jit::load(
FLAGS_torch_pedestrian_interaction_single_lstm_file, device_);
torch_prediction_layer_ptr_ = torch::jit::load(
FLAGS_torch_pedestrian_interaction_prediction_layer_file, device_);
}
bool PedestrianInteractionEvaluator::Evaluate(Obstacle* obstacle_ptr) {
// Sanity checks.
CHECK_NOTNULL(obstacle_ptr);
......
......@@ -21,6 +21,7 @@
#pragma once
#include <memory>
#include <unordered_map>
#include <string>
#include <vector>
......@@ -42,7 +43,7 @@ class PedestrianInteractionEvaluator : public Evaluator {
/**
* @brief Constructor
*/
PedestrianInteractionEvaluator() = default;
PedestrianInteractionEvaluator();
/**
* @brief Destructor
......@@ -77,8 +78,19 @@ class PedestrianInteractionEvaluator : public Evaluator {
void Clear();
void LoadModel();
private:
std::unordered_map<int, LSTMState> obstacle_id_lstm_state_map_;
std::shared_ptr<torch::jit::script::Module>
torch_position_embedding_ptr_ = nullptr;
std::shared_ptr<torch::jit::script::Module>
torch_social_embedding_ptr_ = nullptr;
std::shared_ptr<torch::jit::script::Module>
torch_single_lstm_ptr_ = nullptr;
std::shared_ptr<torch::jit::script::Module>
torch_prediction_layer_ptr_ = nullptr;
torch::Device device_;
};
} // namespace prediction
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册