提交 88864fe4 编写于 作者: K kechxu 提交者: Jiangtao Hu

Prediction: more clear structure of cruise model

上级 fa160405
......@@ -15,14 +15,45 @@
*****************************************************************************/
#include "modules/prediction/network/cruise_model/cruise_model.h"
#include "cyber/common/log.h"
namespace apollo {
namespace prediction {
namespace network {
using apollo::prediction::CruiseModelParameter;
void CruiseModel::Run(const std::vector<Eigen::MatrixXf>& inputs,
Eigen::MatrixXf* output) const {
// TODO(kechxu) implement
// Step 1: Run lane feature conv 1d
// Step 2: Run lane feature max pool 1d
// Step 3: Run lane feature avg pool 1d
// Step 4: Run obstacle feature fully connected
// Step 5: Concatenate [lane_feature, obstacle_feature]
// Step 6: Get classification result
// Step 7: Get regression result
// Step 8: Output
}
bool CruiseModel::LoadModel(
const CruiseModelParameter& cruise_model_parameter) {
CHECK(cruise_model_parameter.has_lane_feature_conv());
CHECK(cruise_model_parameter.has_lane_feature_maxpool());
CHECK(cruise_model_parameter.has_lane_feature_avgpool());
CHECK(cruise_model_parameter.has_obs_feature_fc());
CHECK(cruise_model_parameter.has_classify());
CHECK(cruise_model_parameter.has_regress());
cruise_model_parameter_.CopyFrom(cruise_model_parameter);
return true;
}
} // namespace network
......
......@@ -30,6 +30,14 @@ namespace network {
class CruiseModel : public NetModel {
public:
/**
* @brief Load cruise network parameters from a protobuf message
* @param CruiseModelParameter is a protobuf message
* @return True if successfully loaded, otherwise False
*/
bool LoadModel(
const apollo::prediction::CruiseModelParameter& cruise_model_parameter);
/**
* @brief Compute the model output from inputs according to a defined layers'
* flow
......@@ -38,6 +46,9 @@ class CruiseModel : public NetModel {
*/
void Run(const std::vector<Eigen::MatrixXf>& inputs,
Eigen::MatrixXf* output) const override;
private:
CruiseModelParameter cruise_model_parameter_;
};
} // namespace network
......
......@@ -6,7 +6,7 @@ import "modules/prediction/proto/network_layers.proto";
// Intermediate building blocks:
message LaneFeatureConv {
message LaneFeatureConvParameter {
optional Conv1dParameter conv1d_0 = 1;
optional ActivationParameter activation_1 = 2;
optional Conv1dParameter conv1d_2 = 3;
......@@ -14,14 +14,14 @@ message LaneFeatureConv {
optional Conv1dParameter conv1d_4 = 5;
}
message ObsFeatureFC {
message ObsFeatureFCParameter {
optional DenseParameter linear_0 = 1;
optional ActivationParameter activation_1 = 2;
optional DenseParameter linear_3 = 3;
optional ActivationParameter activation_4 = 4;
}
message Classify {
message ClassifyParameter {
optional DenseParameter linear_0 = 1;
optional ActivationParameter activation_1 = 2;
optional DenseParameter linear_3 = 3;
......@@ -32,7 +32,7 @@ message Classify {
optional ActivationParameter activation_10 = 8;
}
message Regress {
message RegressParameter {
optional DenseParameter linear_0 = 1;
optional ActivationParameter activation_1 = 2;
optional DenseParameter linear_3 = 3;
......@@ -46,11 +46,11 @@ message Regress {
// Final model
message CruiseModel {
optional LaneFeatureConv lane_feature_conv = 1;
message CruiseModelParameter {
optional LaneFeatureConvParameter lane_feature_conv = 1;
optional MaxPool1dParameter lane_feature_maxpool = 2;
optional AvgPool1dParameter lane_feature_avgpool = 3;
optional ObsFeatureFC obs_feature_fc = 5;
optional Classify classify = 6;
optional Regress regress = 7;
optional ObsFeatureFCParameter obs_feature_fc = 5;
optional ClassifyParameter classify = 6;
optional RegressParameter regress = 7;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册