Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Pinoxchio
apollo
提交
d452c9d2
A
apollo
项目概览
Pinoxchio
/
apollo
与 Fork 源项目一致
从无法访问的项目Fork
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
apollo
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d452c9d2
编写于
11月 07, 2018
作者:
K
kechxu
提交者:
Jiangtao Hu
12月 13, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Prediction: load layers from pb
上级
88864fe4
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
117 addition
and
2 deletion
+117
-2
modules/prediction/network/cruise_model/cruise_model.cc
modules/prediction/network/cruise_model/cruise_model.cc
+50
-1
modules/prediction/network/cruise_model/cruise_model.h
modules/prediction/network/cruise_model/cruise_model.h
+39
-1
modules/prediction/network/net_layer.cc
modules/prediction/network/net_layer.cc
+14
-0
modules/prediction/network/net_layer.h
modules/prediction/network/net_layer.h
+14
-0
未找到文件。
modules/prediction/network/cruise_model/cruise_model.cc
浏览文件 @
d452c9d2
...
...
@@ -26,6 +26,9 @@ using apollo::prediction::CruiseModelParameter;
void
CruiseModel
::
Run
(
const
std
::
vector
<
Eigen
::
MatrixXf
>&
inputs
,
Eigen
::
MatrixXf
*
output
)
const
{
// TODO(kechxu) implement
// inputs = {lane_feature, obs_feature}
CHECK_EQ
(
inputs
.
size
(),
2
);
// Step 1: Run lane feature conv 1d
// Step 2: Run lane feature max pool 1d
...
...
@@ -52,7 +55,53 @@ bool CruiseModel::LoadModel(
CHECK
(
cruise_model_parameter
.
has_classify
());
CHECK
(
cruise_model_parameter
.
has_regress
());
cruise_model_parameter_
.
CopyFrom
(
cruise_model_parameter
);
// Load LaneFeatureConvParameter
const
auto
&
lane_conv1d_param
=
cruise_model_parameter
.
lane_feature_conv
();
lane_conv1d_0_
.
Load
(
lane_conv1d_param
.
conv1d_0
());
lane_activation_1_
.
Load
(
lane_conv1d_param
.
activation_1
());
lane_conv1d_2_
.
Load
(
lane_conv1d_param
.
conv1d_2
());
lane_activation_3_
.
Load
(
lane_conv1d_param
.
activation_3
());
lane_conv1d_4_
.
Load
(
lane_conv1d_param
.
conv1d_4
());
// Load MaxPool1dParameter
const
auto
&
lane_maxpool1d_param
=
cruise_model_parameter
.
lane_feature_maxpool
();
lane_maxpool1d_
.
Load
(
lane_maxpool1d_param
);
// Load AvgPool1dParameter
const
auto
&
lane_avgpool1d_param
=
cruise_model_parameter
.
lane_feature_avgpool
();
lane_avgpool1d_
.
Load
(
lane_avgpool1d_param
);
// Load ObsFeatureFCParameter
const
auto
&
obs_fc_param
=
cruise_model_parameter
.
obs_feature_fc
();
obs_linear_0_
.
Load
(
obs_fc_param
.
linear_0
());
obs_activation_1_
.
Load
(
obs_fc_param
.
activation_1
());
obs_linear_3_
.
Load
(
obs_fc_param
.
linear_3
());
obs_activation_4_
.
Load
(
obs_fc_param
.
activation_4
());
// Load ClassifyParameter
const
auto
&
classify_param
=
cruise_model_parameter
.
classify
();
classify_linear_0_
.
Load
(
classify_param
.
linear_0
());
classify_activation_1_
.
Load
(
classify_param
.
activation_1
());
classify_linear_3_
.
Load
(
classify_param
.
linear_3
());
classify_activation_4_
.
Load
(
classify_param
.
activation_4
());
classify_linear_6_
.
Load
(
classify_param
.
linear_6
());
classify_activation_7_
.
Load
(
classify_param
.
activation_7
());
classify_linear_9_
.
Load
(
classify_param
.
linear_9
());
classify_activation_10_
.
Load
(
classify_param
.
activation_10
());
// Load RegressParameter
const
auto
&
regress_param
=
cruise_model_parameter
.
regress
();
regress_linear_0_
.
Load
(
regress_param
.
linear_0
());
regress_activation_1_
.
Load
(
regress_param
.
activation_1
());
regress_linear_3_
.
Load
(
regress_param
.
linear_3
());
regress_activation_4_
.
Load
(
regress_param
.
activation_4
());
regress_linear_6_
.
Load
(
regress_param
.
linear_6
());
regress_activation_7_
.
Load
(
regress_param
.
activation_7
());
regress_linear_9_
.
Load
(
regress_param
.
linear_9
());
regress_activation_10_
.
Load
(
regress_param
.
activation_10
());
return
true
;
}
...
...
modules/prediction/network/cruise_model/cruise_model.h
浏览文件 @
d452c9d2
...
...
@@ -23,6 +23,7 @@
#include "cyber/common/macros.h"
#include "modules/prediction/proto/cruise_model.pb.h"
#include "modules/prediction/network/net_model.h"
#include "modules/prediction/network/net_layer.h"
namespace
apollo
{
namespace
prediction
{
...
...
@@ -48,7 +49,44 @@ class CruiseModel : public NetModel {
Eigen
::
MatrixXf
*
output
)
const
override
;
private:
CruiseModelParameter
cruise_model_parameter_
;
// LaneFeatureConvParameter
Conv1d
lane_conv1d_0_
;
Activation
lane_activation_1_
;
Conv1d
lane_conv1d_2_
;
Activation
lane_activation_3_
;
Conv1d
lane_conv1d_4_
;
// MaxPool1dParameter
MaxPool1d
lane_maxpool1d_
;
// AvgPool1dParameter
AvgPool1d
lane_avgpool1d_
;
// ObsFeatureFCParameter
Dense
obs_linear_0_
;
Activation
obs_activation_1_
;
Dense
obs_linear_3_
;
Activation
obs_activation_4_
;
// ClassifyParameter
Dense
classify_linear_0_
;
Activation
classify_activation_1_
;
Dense
classify_linear_3_
;
Activation
classify_activation_4_
;
Dense
classify_linear_6_
;
Activation
classify_activation_7_
;
Dense
classify_linear_9_
;
Activation
classify_activation_10_
;
// RegressParameter
Dense
regress_linear_0_
;
Activation
regress_activation_1_
;
Dense
regress_linear_3_
;
Activation
regress_activation_4_
;
Dense
regress_linear_6_
;
Activation
regress_activation_7_
;
Dense
regress_linear_9_
;
Activation
regress_activation_10_
;
};
}
// namespace network
...
...
modules/prediction/network/net_layer.cc
浏览文件 @
d452c9d2
...
...
@@ -33,6 +33,7 @@ using apollo::prediction::DenseParameter;
using
apollo
::
prediction
::
LayerParameter
;
using
apollo
::
prediction
::
Conv1dParameter
;
using
apollo
::
prediction
::
MaxPool1dParameter
;
using
apollo
::
prediction
::
ActivationParameter
;
bool
Layer
::
Load
(
const
LayerParameter
&
layer_pb
)
{
if
(
!
layer_pb
.
has_name
())
{
...
...
@@ -56,6 +57,10 @@ bool Dense::Load(const LayerParameter& layer_pb) {
return
false
;
}
DenseParameter
dense_pb
=
layer_pb
.
dense
();
return
Load
(
dense_pb
);
}
bool
Dense
::
Load
(
const
DenseParameter
&
dense_pb
)
{
if
(
!
dense_pb
.
has_weights
()
||
!
LoadTensor
(
dense_pb
.
weights
(),
&
weights_
))
{
AERROR
<<
"Fail to Load weights!"
;
return
false
;
...
...
@@ -250,6 +255,15 @@ bool Activation::Load(const LayerParameter& layer_pb) {
return
true
;
}
bool
Activation
::
Load
(
const
ActivationParameter
&
activation_pb
)
{
if
(
!
activation_pb
.
has_activation
())
{
kactivation_
=
serialize_to_function
(
"linear"
);
}
else
{
kactivation_
=
serialize_to_function
(
activation_pb
.
activation
());
}
return
true
;
}
void
Activation
::
Run
(
const
std
::
vector
<
Eigen
::
MatrixXf
>&
inputs
,
Eigen
::
MatrixXf
*
output
)
{
CHECK_EQ
(
inputs
.
size
(),
1
);
...
...
modules/prediction/network/net_layer.h
浏览文件 @
d452c9d2
...
...
@@ -118,6 +118,13 @@ class Dense : public Layer {
*/
bool
Load
(
const
apollo
::
prediction
::
LayerParameter
&
layer_pb
)
override
;
/**
* @brief Load the dense layer parameter from a pb message
* @param A pb message contains the parameters
* @return True is loaded successively, otherwise False
*/
bool
Load
(
const
apollo
::
prediction
::
DenseParameter
&
layer_pb
);
/**
* @brief Compute the layer output from inputs
* @param Inputs to a network layer
...
...
@@ -258,6 +265,13 @@ class Activation : public Layer {
*/
bool
Load
(
const
apollo
::
prediction
::
LayerParameter
&
layer_pb
)
override
;
/**
* @brief Load the parameter from a pb message
* @param A pb message contains the parameters
* @return True is loaded successively, otherwise False
*/
bool
Load
(
const
apollo
::
prediction
::
ActivationParameter
&
activation_pb
);
/**
* @brief Compute the layer output from inputs
* @param Inputs to a network layer
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录