提交 d452c9d2 编写于 作者: K kechxu 提交者: Jiangtao Hu

Prediction: load layers from pb

上级 88864fe4
......@@ -26,6 +26,9 @@ using apollo::prediction::CruiseModelParameter;
void CruiseModel::Run(const std::vector<Eigen::MatrixXf>& inputs,
Eigen::MatrixXf* output) const {
// TODO(kechxu) implement
// inputs = {lane_feature, obs_feature}
CHECK_EQ(inputs.size(), 2);
// Step 1: Run lane feature conv 1d
// Step 2: Run lane feature max pool 1d
......@@ -52,7 +55,53 @@ bool CruiseModel::LoadModel(
CHECK(cruise_model_parameter.has_classify());
CHECK(cruise_model_parameter.has_regress());
cruise_model_parameter_.CopyFrom(cruise_model_parameter);
// Load LaneFeatureConvParameter
const auto& lane_conv1d_param = cruise_model_parameter.lane_feature_conv();
lane_conv1d_0_.Load(lane_conv1d_param.conv1d_0());
lane_activation_1_.Load(lane_conv1d_param.activation_1());
lane_conv1d_2_.Load(lane_conv1d_param.conv1d_2());
lane_activation_3_.Load(lane_conv1d_param.activation_3());
lane_conv1d_4_.Load(lane_conv1d_param.conv1d_4());
// Load MaxPool1dParameter
const auto& lane_maxpool1d_param =
cruise_model_parameter.lane_feature_maxpool();
lane_maxpool1d_.Load(lane_maxpool1d_param);
// Load AvgPool1dParameter
const auto& lane_avgpool1d_param =
cruise_model_parameter.lane_feature_avgpool();
lane_avgpool1d_.Load(lane_avgpool1d_param);
// Load ObsFeatureFCParameter
const auto& obs_fc_param = cruise_model_parameter.obs_feature_fc();
obs_linear_0_.Load(obs_fc_param.linear_0());
obs_activation_1_.Load(obs_fc_param.activation_1());
obs_linear_3_.Load(obs_fc_param.linear_3());
obs_activation_4_.Load(obs_fc_param.activation_4());
// Load ClassifyParameter
const auto& classify_param = cruise_model_parameter.classify();
classify_linear_0_.Load(classify_param.linear_0());
classify_activation_1_.Load(classify_param.activation_1());
classify_linear_3_.Load(classify_param.linear_3());
classify_activation_4_.Load(classify_param.activation_4());
classify_linear_6_.Load(classify_param.linear_6());
classify_activation_7_.Load(classify_param.activation_7());
classify_linear_9_.Load(classify_param.linear_9());
classify_activation_10_.Load(classify_param.activation_10());
// Load RegressParameter
const auto& regress_param = cruise_model_parameter.regress();
regress_linear_0_.Load(regress_param.linear_0());
regress_activation_1_.Load(regress_param.activation_1());
regress_linear_3_.Load(regress_param.linear_3());
regress_activation_4_.Load(regress_param.activation_4());
regress_linear_6_.Load(regress_param.linear_6());
regress_activation_7_.Load(regress_param.activation_7());
regress_linear_9_.Load(regress_param.linear_9());
regress_activation_10_.Load(regress_param.activation_10());
return true;
}
......
......@@ -23,6 +23,7 @@
#include "cyber/common/macros.h"
#include "modules/prediction/proto/cruise_model.pb.h"
#include "modules/prediction/network/net_model.h"
#include "modules/prediction/network/net_layer.h"
namespace apollo {
namespace prediction {
......@@ -48,7 +49,44 @@ class CruiseModel : public NetModel {
Eigen::MatrixXf* output) const override;
private:
CruiseModelParameter cruise_model_parameter_;
// LaneFeatureConvParameter
Conv1d lane_conv1d_0_;
Activation lane_activation_1_;
Conv1d lane_conv1d_2_;
Activation lane_activation_3_;
Conv1d lane_conv1d_4_;
// MaxPool1dParameter
MaxPool1d lane_maxpool1d_;
// AvgPool1dParameter
AvgPool1d lane_avgpool1d_;
// ObsFeatureFCParameter
Dense obs_linear_0_;
Activation obs_activation_1_;
Dense obs_linear_3_;
Activation obs_activation_4_;
// ClassifyParameter
Dense classify_linear_0_;
Activation classify_activation_1_;
Dense classify_linear_3_;
Activation classify_activation_4_;
Dense classify_linear_6_;
Activation classify_activation_7_;
Dense classify_linear_9_;
Activation classify_activation_10_;
// RegressParameter
Dense regress_linear_0_;
Activation regress_activation_1_;
Dense regress_linear_3_;
Activation regress_activation_4_;
Dense regress_linear_6_;
Activation regress_activation_7_;
Dense regress_linear_9_;
Activation regress_activation_10_;
};
} // namespace network
......
......@@ -33,6 +33,7 @@ using apollo::prediction::DenseParameter;
using apollo::prediction::LayerParameter;
using apollo::prediction::Conv1dParameter;
using apollo::prediction::MaxPool1dParameter;
using apollo::prediction::ActivationParameter;
bool Layer::Load(const LayerParameter& layer_pb) {
if (!layer_pb.has_name()) {
......@@ -56,6 +57,10 @@ bool Dense::Load(const LayerParameter& layer_pb) {
return false;
}
DenseParameter dense_pb = layer_pb.dense();
return Load(dense_pb);
}
bool Dense::Load(const DenseParameter& dense_pb) {
if (!dense_pb.has_weights() || !LoadTensor(dense_pb.weights(), &weights_)) {
AERROR << "Fail to Load weights!";
return false;
......@@ -250,6 +255,15 @@ bool Activation::Load(const LayerParameter& layer_pb) {
return true;
}
bool Activation::Load(const ActivationParameter& activation_pb) {
if (!activation_pb.has_activation()) {
kactivation_ = serialize_to_function("linear");
} else {
kactivation_ = serialize_to_function(activation_pb.activation());
}
return true;
}
void Activation::Run(const std::vector<Eigen::MatrixXf>& inputs,
Eigen::MatrixXf* output) {
CHECK_EQ(inputs.size(), 1);
......
......@@ -118,6 +118,13 @@ class Dense : public Layer {
*/
bool Load(const apollo::prediction::LayerParameter& layer_pb) override;
/**
* @brief Load the dense layer parameter from a pb message
* @param A pb message contains the parameters
* @return True is loaded successively, otherwise False
*/
bool Load(const apollo::prediction::DenseParameter& layer_pb);
/**
* @brief Compute the layer output from inputs
* @param Inputs to a network layer
......@@ -258,6 +265,13 @@ class Activation : public Layer {
*/
bool Load(const apollo::prediction::LayerParameter& layer_pb) override;
/**
* @brief Load the parameter from a pb message
* @param A pb message contains the parameters
* @return True is loaded successively, otherwise False
*/
bool Load(const apollo::prediction::ActivationParameter& activation_pb);
/**
* @brief Compute the layer output from inputs
* @param Inputs to a network layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册