Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Pinoxchio
apollo
提交
9a704234
A
apollo
项目概览
Pinoxchio
/
apollo
与 Fork 源项目一致
从无法访问的项目Fork
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
apollo
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
9a704234
编写于
7月 20, 2017
作者:
C
Calvin Miao
浏览文件
操作
浏览文件
下载
差异文件
Resolved planner conflicts
上级
3a630e5f
3b35cb85
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
346 addition
and
65 deletion
+346
-65
modules/planning/BUILD
modules/planning/BUILD
+10
-9
modules/planning/common/BUILD
modules/planning/common/BUILD
+1
-1
modules/planning/planner/BUILD
modules/planning/planner/BUILD
+3
-31
modules/planning/planner/em/BUILD
modules/planning/planner/em/BUILD
+27
-0
modules/planning/planner/em/em_planner.cc
modules/planning/planner/em/em_planner.cc
+86
-0
modules/planning/planner/em/em_planner.h
modules/planning/planner/em/em_planner.h
+7
-6
modules/planning/planner/planner.h
modules/planning/planner/planner.h
+3
-3
modules/planning/planner/rtk/BUILD
modules/planning/planner/rtk/BUILD
+43
-0
modules/planning/planner/rtk/rtk_replay_planner.cc
modules/planning/planner/rtk/rtk_replay_planner.cc
+2
-2
modules/planning/planner/rtk/rtk_replay_planner.h
modules/planning/planner/rtk/rtk_replay_planner.h
+3
-3
modules/planning/planner/rtk/rtk_replay_planner_test.cc
modules/planning/planner/rtk/rtk_replay_planner_test.cc
+3
-3
modules/planning/planning.cc
modules/planning/planning.cc
+6
-3
modules/planning/proxy/BUILD
modules/planning/proxy/BUILD
+1
-1
modules/prediction/common/BUILD
modules/prediction/common/BUILD
+7
-0
modules/prediction/common/prediction_util.cc
modules/prediction/common/prediction_util.cc
+40
-0
modules/prediction/common/prediction_util.h
modules/prediction/common/prediction_util.h
+34
-0
modules/prediction/evaluator/vehicle/BUILD
modules/prediction/evaluator/vehicle/BUILD
+1
-0
modules/prediction/evaluator/vehicle/mlp_evaluator.cc
modules/prediction/evaluator/vehicle/mlp_evaluator.cc
+69
-3
未找到文件。
modules/planning/BUILD
浏览文件 @
9a704234
...
...
@@ -3,24 +3,25 @@ load("//tools:cpplint.bzl", "cpplint")
package
(
default_visibility
=
[
"//visibility:public"
])
cc_library
(
name
=
"
lib_planning
"
,
name
=
"
planning_impl
"
,
srcs
=
[
"planning.cc"
,
],
hdrs
=
glob
(
[
"
*
.h"
,
]
)
,
hdrs
=
[
"
planning
.h"
,
],
deps
=
[
"//modules/common"
,
"//modules/common:apollo_app"
,
"//modules/common:log"
,
"//modules/common/adapters:adapter_manager"
,
"//modules/common:apollo_app"
,
"//modules/common/proto:path_point_proto"
,
"//modules/common/vehicle_state"
,
"//modules/decision/proto:decision_proto"
,
"//modules/perception/proto:perception_proto"
,
"//modules/planning/common:lib_planning_common"
,
"//modules/planning/planner:lib_planner"
,
"//modules/planning/common:planning_common"
,
"//modules/planning/planner/em:em_planner"
,
"//modules/planning/planner/rtk:rtk_planner"
,
"//modules/planning/proto:planning_proto"
,
"//modules/prediction/proto:prediction_proto"
,
"@ros//:ros_common"
,
...
...
@@ -31,7 +32,7 @@ cc_binary(
name
=
"planning"
,
srcs
=
[
"main.cc"
],
deps
=
[
":
lib_planning
"
,
":
planning_impl
"
,
"//external:gflags"
,
],
)
...
...
@@ -44,7 +45,7 @@ cc_test(
],
data
=
[
"//modules/planning:planning_testdata"
],
deps
=
[
":
lib_planning
"
,
":
planning_impl
"
,
"//modules/common:log"
,
"//modules/common/time"
,
"//modules/common/util"
,
...
...
modules/planning/common/BUILD
浏览文件 @
9a704234
...
...
@@ -168,7 +168,7 @@ cc_library(
)
cc_library
(
name
=
"
lib_
planning_common"
,
name
=
"planning_common"
,
deps
=
[
"@eigen//:eigen"
,
":frame"
,
...
...
modules/planning/planner/BUILD
浏览文件 @
9a704234
...
...
@@ -3,41 +3,13 @@ load("//tools:cpplint.bzl", "cpplint")
package
(
default_visibility
=
[
"//visibility:public"
])
cc_library
(
name
=
"lib_planner"
,
srcs
=
[
"rtk_replay_planner.cc"
,
"em_planner.cc"
,
name
=
"planner"
,
hdrs
=
[
"planner.h"
,
],
hdrs
=
glob
([
"*.h"
,
]),
deps
=
[
"//external:gflags"
,
"//modules/common:log"
,
"//modules/common/proto:path_point_proto"
,
"//modules/common/util"
,
"//modules/common/util:factory"
,
"//modules/common/vehicle_state"
,
"//modules/planning/common:lib_planning_common"
,
"//modules/planning/math/curve1d:quartic_polynomial_curve1d"
,
"@eigen//:eigen"
,
],
)
cc_test
(
name
=
"rtk_replay_planner_test"
,
size
=
"small"
,
srcs
=
[
"rtk_replay_planner_test.cc"
,
],
data
=
[
"//modules/planning:planning_testdata"
],
deps
=
[
":lib_planner"
,
"//modules/common:log"
,
"//modules/common/time"
,
"//modules/common/util"
,
"//modules/planning/common:lib_planning_common"
,
"@gtest//:main"
,
],
)
...
...
modules/planning/planner/em/BUILD
0 → 100644
浏览文件 @
9a704234
load
(
"//tools:cpplint.bzl"
,
"cpplint"
)
package
(
default_visibility
=
[
"//visibility:public"
])
cc_library
(
name
=
"em_planner"
,
srcs
=
[
"em_planner.cc"
,
],
hdrs
=
[
"em_planner.h"
,
],
deps
=
[
"//external:gflags"
,
"//modules/common:log"
,
"//modules/common/proto:path_point_proto"
,
"//modules/common/util"
,
"//modules/common/util:factory"
,
"//modules/common/vehicle_state"
,
"//modules/planning/common:planning_common"
,
"//modules/planning/math/curve1d:quartic_polynomial_curve1d"
,
"//modules/planning/planner"
,
"@eigen//:eigen"
,
],
)
cpplint
()
modules/planning/planner/em_planner.cc
→
modules/planning/planner/em
/em
_planner.cc
浏览文件 @
9a704234
...
...
@@ -14,12 +14,14 @@
* limitations under the License.
*****************************************************************************/
#include "modules/planning/planner/em_planner.h"
#include "modules/planning/planner/em
/em
_planner.h"
#include <fstream>
#include <utility>
#include "modules/common/log.h"
#include "modules/common/util/string_tokenizer.h"
#include "modules/planning/common/data_center.h"
#include "modules/planning/common/planning_gflags.h"
#include "modules/planning/math/curve1d/quartic_polynomial_curve1d.h"
...
...
@@ -31,51 +33,54 @@ using apollo::common::vehicle_state::VehicleState;
EMPlanner
::
EMPlanner
()
{}
bool
EMPlanner
::
Plan
(
const
TrajectoryPoint
&
start_point
,
std
::
vector
<
TrajectoryPoint
>*
discretized_trajectory
)
{
bool
EMPlanner
::
Make
Plan
(
const
TrajectoryPoint
&
start_point
,
std
::
vector
<
TrajectoryPoint
>*
discretized_trajectory
)
{
return
true
;
}
std
::
vector
<
SpeedPoint
>
EMPlanner
::
generate_init_speed_profile
(
const
double
init_v
,
const
double
init_a
)
{
//TODO: this is a dummy simple hot start, need refine later
std
::
array
<
double
,
3
>
start_state
;
// distance 0.0
start_state
[
0
]
=
0.0
;
// start velocity
start_state
[
1
]
=
init_v
;
// start acceleration
start_state
[
2
]
=
init_a
;
std
::
array
<
double
,
2
>
end_state
;
// end state velocity
end_state
[
0
]
=
10.0
;
// end state acceleration
end_state
[
1
]
=
0.0
;
// pre assume the curve time is 8 second, can be change later
QuarticPolynomialCurve1d
speed_curve
(
start_state
,
end_state
,
FLAGS_trajectory_time_length
);
// assume the time resolution is 0.1
std
::
size_t
num_time_steps
=
static_cast
<
std
::
size_t
>
(
FLAGS_trajectory_time_length
/
FLAGS_trajectory_time_resolution
)
+
1
;
std
::
vector
<
SpeedPoint
>
speed_profile
;
speed_profile
.
reserve
(
num_time_steps
);
for
(
std
::
size_t
i
=
0
;
i
<
num_time_steps
;
++
i
)
{
double
t
=
i
*
FLAGS_trajectory_time_resolution
;
double
s
=
speed_curve
.
evaluate
(
0
,
t
);
double
v
=
speed_curve
.
evaluate
(
1
,
t
);
double
a
=
speed_curve
.
evaluate
(
2
,
t
);
double
j
=
speed_curve
.
evaluate
(
3
,
t
);
speed_profile
.
emplace_back
(
s
,
t
,
v
,
a
,
j
);
}
return
std
::
move
(
speed_profile
);
std
::
vector
<
SpeedPoint
>
EMPlanner
::
GenerateInitSpeedProfile
(
const
double
init_v
,
const
double
init_a
)
{
// TODO: this is a dummy simple hot start, need refine later
std
::
array
<
double
,
3
>
start_state
;
// distance 0.0
start_state
[
0
]
=
0.0
;
// start velocity
start_state
[
1
]
=
init_v
;
// start acceleration
start_state
[
2
]
=
init_a
;
std
::
array
<
double
,
2
>
end_state
;
// end state velocity
end_state
[
0
]
=
10.0
;
// end state acceleration
end_state
[
1
]
=
0.0
;
// pre assume the curve time is 8 second, can be change later
QuarticPolynomialCurve1d
speed_curve
(
start_state
,
end_state
,
FLAGS_trajectory_time_length
);
// assume the time resolution is 0.1
std
::
size_t
num_time_steps
=
static_cast
<
std
::
size_t
>
(
FLAGS_trajectory_time_length
/
FLAGS_trajectory_time_resolution
)
+
1
;
std
::
vector
<
SpeedPoint
>
speed_profile
;
speed_profile
.
reserve
(
num_time_steps
);
for
(
std
::
size_t
i
=
0
;
i
<
num_time_steps
;
++
i
)
{
double
t
=
i
*
FLAGS_trajectory_time_resolution
;
double
s
=
speed_curve
.
evaluate
(
0
,
t
);
double
v
=
speed_curve
.
evaluate
(
1
,
t
);
double
a
=
speed_curve
.
evaluate
(
2
,
t
);
double
j
=
speed_curve
.
evaluate
(
3
,
t
);
speed_profile
.
emplace_back
(
s
,
t
,
v
,
a
,
j
);
}
return
std
::
move
(
speed_profile
);
}
}
// namespace planning
}
// nameapace apollo
modules/planning/planner/em_planner.h
→
modules/planning/planner/em
/em
_planner.h
浏览文件 @
9a704234
...
...
@@ -20,8 +20,9 @@
#include <string>
#include <vector>
#include "modules/
planning/planner/planner
.h"
#include "modules/
common/proto/path_point.pb
.h"
#include "modules/planning/common/speed/speed_point.h"
#include "modules/planning/planner/planner.h"
/**
* @namespace apollo::planning
...
...
@@ -53,13 +54,13 @@ class EMPlanner : public Planner {
* @param discretized_trajectory The computed trajectory
* @return true if planning succeeds; false otherwise.
*/
bool
Plan
(
const
apollo
::
common
::
TrajectoryPoint
&
start_point
,
std
::
vector
<
apollo
::
common
::
TrajectoryPoint
>*
trajectory
)
override
;
bool
MakePlan
(
const
apollo
::
common
::
TrajectoryPoint
&
start_point
,
std
::
vector
<
apollo
::
common
::
TrajectoryPoint
>*
trajectory
)
override
;
private:
std
::
vector
<
SpeedPoint
>
generate_init_speed_profile
(
const
double
init_v
,
const
double
init_a
);
std
::
vector
<
SpeedPoint
>
GenerateInitSpeedProfile
(
const
double
init_v
,
const
double
init_a
);
};
}
// namespace planning
...
...
modules/planning/planner/planner.h
浏览文件 @
9a704234
...
...
@@ -53,9 +53,9 @@ class Planner {
* @param discretized_trajectory The computed trajectory
* @return true if planning succeeds; false otherwise.
*/
virtual
bool
Plan
(
const
apollo
::
common
::
TrajectoryPoint
&
start_point
,
std
::
vector
<
apollo
::
common
::
TrajectoryPoint
>
*
discretized_trajectory
)
=
0
;
virtual
bool
Make
Plan
(
const
apollo
::
common
::
TrajectoryPoint
&
start_point
,
std
::
vector
<
apollo
::
common
::
TrajectoryPoint
>
*
discretized_trajectory
)
=
0
;
};
}
// namespace planning
...
...
modules/planning/planner/rtk/BUILD
0 → 100644
浏览文件 @
9a704234
load
(
"//tools:cpplint.bzl"
,
"cpplint"
)
package
(
default_visibility
=
[
"//visibility:public"
])
cc_library
(
name
=
"rtk_planner"
,
srcs
=
[
"rtk_replay_planner.cc"
,
],
hdrs
=
[
"rtk_replay_planner.h"
,
],
deps
=
[
"//external:gflags"
,
"//modules/common:log"
,
"//modules/common/proto:path_point_proto"
,
"//modules/common/util"
,
"//modules/common/util:factory"
,
"//modules/common/vehicle_state"
,
"//modules/planning/common:planning_common"
,
"//modules/planning/math/curve1d:quartic_polynomial_curve1d"
,
"//modules/planning/planner"
,
"@eigen//:eigen"
,
],
)
cc_test
(
name
=
"rtk_replay_planner_test"
,
size
=
"small"
,
srcs
=
[
"rtk_replay_planner_test.cc"
,
],
data
=
[
"//modules/planning:planning_testdata"
],
deps
=
[
":rtk_planner"
,
"//modules/common:log"
,
"//modules/common/time"
,
"//modules/common/util"
,
"@gtest//:main"
,
],
)
cpplint
()
modules/planning/planner/rtk_replay_planner.cc
→
modules/planning/planner/rtk
/rtk
_replay_planner.cc
浏览文件 @
9a704234
...
...
@@ -14,7 +14,7 @@
* limitations under the License.
*****************************************************************************/
#include "modules/planning/planner/rtk_replay_planner.h"
#include "modules/planning/planner/rtk
/rtk
_replay_planner.h"
#include <fstream>
...
...
@@ -32,7 +32,7 @@ RTKReplayPlanner::RTKReplayPlanner() {
ReadTrajectoryFile
(
FLAGS_rtk_trajectory_filename
);
}
bool
RTKReplayPlanner
::
Plan
(
bool
RTKReplayPlanner
::
Make
Plan
(
const
TrajectoryPoint
&
start_point
,
std
::
vector
<
TrajectoryPoint
>*
ptr_discretized_trajectory
)
{
if
(
complete_rtk_trajectory_
.
empty
()
||
complete_rtk_trajectory_
.
size
()
<
2
)
{
...
...
modules/planning/planner/rtk_replay_planner.h
→
modules/planning/planner/rtk
/rtk
_replay_planner.h
浏览文件 @
9a704234
...
...
@@ -56,9 +56,9 @@ class RTKReplayPlanner : public Planner {
* @param discretized_trajectory The computed trajectory
* @return true if planning succeeds; false otherwise.
*/
bool
Plan
(
const
apollo
::
common
::
TrajectoryPoint
&
start_point
,
std
::
vector
<
apollo
::
common
::
TrajectoryPoint
>
*
ptr_trajectory
)
override
;
bool
Make
Plan
(
const
apollo
::
common
::
TrajectoryPoint
&
start_point
,
std
::
vector
<
apollo
::
common
::
TrajectoryPoint
>
*
ptr_trajectory
)
override
;
/**
* @brief Read the recorded trajectory file.
...
...
modules/planning/planner/rtk_replay_planner_test.cc
→
modules/planning/planner/rtk
/rtk
_replay_planner_test.cc
浏览文件 @
9a704234
...
...
@@ -14,7 +14,7 @@
* limitations under the License.
*****************************************************************************/
#include "modules/planning/planner/rtk_replay_planner.h"
#include "modules/planning/planner/rtk
/rtk
_replay_planner.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
...
...
@@ -37,7 +37,7 @@ TEST_F(RTKReplayPlannerTest, ComputeTrajectory) {
start_point
.
mutable_path_point
()
->
set_y
(
4140674.76063
);
std
::
vector
<
TrajectoryPoint
>
trajectory
;
bool
planning_succeeded
=
planner
.
Plan
(
start_point
,
&
trajectory
);
bool
planning_succeeded
=
planner
.
Make
Plan
(
start_point
,
&
trajectory
);
EXPECT_TRUE
(
planning_succeeded
);
EXPECT_TRUE
(
!
trajectory
.
empty
());
...
...
@@ -62,7 +62,7 @@ TEST_F(RTKReplayPlannerTest, ErrorTest) {
start_point
.
mutable_path_point
()
->
set_x
(
586385.782842
);
start_point
.
mutable_path_point
()
->
set_y
(
4140674.76063
);
std
::
vector
<
TrajectoryPoint
>
trajectory
;
EXPECT_TRUE
(
!
planner_with_error_csv
.
Plan
(
start_point
,
&
trajectory
));
EXPECT_TRUE
(
!
planner_with_error_csv
.
Make
Plan
(
start_point
,
&
trajectory
));
}
}
// namespace control
...
...
modules/planning/planning.cc
浏览文件 @
9a704234
...
...
@@ -17,7 +17,8 @@
#include "modules/common/adapters/adapter_manager.h"
#include "modules/common/time/time.h"
#include "modules/planning/common/planning_gflags.h"
#include "modules/planning/planner/rtk_replay_planner.h"
#include "modules/planning/planner/em/em_planner.h"
#include "modules/planning/planner/rtk/rtk_replay_planner.h"
#include "modules/planning/planning.h"
#include "modules/planning/planning.h"
...
...
@@ -36,6 +37,8 @@ std::string Planning::Name() const { return "planning"; }
void
Planning
::
RegisterPlanners
()
{
planner_factory_
.
Register
(
PlanningConfig
::
RTK
,
[]()
->
Planner
*
{
return
new
RTKReplayPlanner
();
});
planner_factory_
.
Register
(
PlanningConfig
::
EM
,
[]()
->
Planner
*
{
return
new
EMPlanner
();
});
}
Status
Planning
::
Init
()
{
...
...
@@ -157,7 +160,7 @@ bool Planning::Plan(const common::vehicle_state::VehicleState& vehicle_state,
// planned trajectory from the matched point, the matched point has
// relative time 0.
bool
planning_succeeded
=
planner_
->
Plan
(
matched_point
,
planning_trajectory
);
planner_
->
Make
Plan
(
matched_point
,
planning_trajectory
);
if
(
!
planning_succeeded
)
{
last_trajectory_
.
clear
();
...
...
@@ -187,7 +190,7 @@ bool Planning::Plan(const common::vehicle_state::VehicleState& vehicle_state,
ComputeStartingPointFromVehicleState
(
vehicle_state
,
planning_cycle_time
);
bool
planning_succeeded
=
planner_
->
Plan
(
vehicle_state_point
,
planning_trajectory
);
planner_
->
Make
Plan
(
vehicle_state_point
,
planning_trajectory
);
if
(
!
planning_succeeded
)
{
last_trajectory_
.
clear
();
return
false
;
...
...
modules/planning/proxy/BUILD
浏览文件 @
9a704234
...
...
@@ -33,7 +33,7 @@ cc_library(
"//modules/common/proto:error_code_proto"
,
"//modules/common/proto:path_point_proto"
,
"//modules/common/configs:vehicle_config_helper"
,
"//modules/planning/common:
lib_
planning_common"
,
"//modules/planning/common:planning_common"
,
"//modules/canbus/proto:canbus_proto"
,
"//modules/localization/proto:localization_proto"
,
],
...
...
modules/prediction/common/BUILD
浏览文件 @
9a704234
...
...
@@ -12,4 +12,11 @@ cc_library(
],
)
cc_library
(
name
=
"prediction_util"
,
srcs
=
[
"prediction_util.cc"
],
hdrs
=
[
"prediction_util.h"
],
deps
=
[],
)
cpplint
()
modules/prediction/common/prediction_util.cc
0 → 100644
浏览文件 @
9a704234
/******************************************************************************
* Copyright 2017 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/
#include <cmath>
#include "modules/prediction/common/prediction_util.h"
namespace
apollo
{
namespace
prediction
{
namespace
util
{
double
Normalize
(
const
double
value
,
const
double
mean
,
const
double
std
)
{
double
eps
=
1e-10
;
return
(
value
-
mean
)
/
(
std
+
eps
);
}
double
Sigmoid
(
const
double
value
)
{
return
1
/
(
1
+
std
::
exp
(
-
1.0
*
value
));
}
double
Relu
(
const
double
value
)
{
return
(
value
>
0.0
)
?
value
:
0.0
;
}
}
// namespace util
}
// namespace prediction
}
// namespace apollo
modules/prediction/common/prediction_util.h
0 → 100644
浏览文件 @
9a704234
/******************************************************************************
* Copyright 2017 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/
#ifndef MODULES_PREDICTION_COMMON_PREDICTION_UTIL_H_
#define MODULES_PREDICTION_COMMON_PREDICTION_UTIL_H_
namespace
apollo
{
namespace
prediction
{
namespace
util
{
double
Normalize
(
const
double
value
,
const
double
mean
,
const
double
std
);
double
Sigmoid
(
const
double
value
);
double
Relu
(
const
double
value
);
}
// namespace util
}
// namespace prediction
}
// namespace apollo
#endif // MODULES_PREDICTION_COMMON_PREDICTION_UTIL_H_
modules/prediction/evaluator/vehicle/BUILD
浏览文件 @
9a704234
...
...
@@ -17,6 +17,7 @@ cc_library(
"//modules/prediction/common:prediction_common"
,
"//modules/common/math:math_utils"
,
"//modules/prediction/proto:fnn_vehicle_model_proto"
,
"//modules/prediction/common:prediction_util"
,
],
)
...
...
modules/prediction/evaluator/vehicle/mlp_evaluator.cc
浏览文件 @
9a704234
...
...
@@ -15,10 +15,12 @@
*****************************************************************************/
#include <cmath>
#include <fstream>
#include "modules/prediction/evaluator/vehicle/mlp_evaluator.h"
#include "modules/prediction/common/prediction_gflags.h"
#include "modules/common/math/math_utils.h"
#include "modules/prediction/common/prediction_util.h"
namespace
apollo
{
namespace
prediction
{
...
...
@@ -254,12 +256,76 @@ void MLPEvaluator::SetLaneFeatureValues(Obstacle* obstacle_ptr,
}
void
MLPEvaluator
::
LoadModel
(
const
std
::
string
&
model_file
)
{
// TODO(kechxu) implement
model_ptr_
.
reset
(
new
FnnVehicleModel
());
CHECK
(
model_ptr_
!=
nullptr
);
std
::
fstream
file_stream
(
model_file
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
if
(
!
file_stream
.
good
())
{
AERROR
<<
"Unable to open the model file: "
<<
model_file
<<
"."
;
return
;
}
if
(
!
model_ptr_
->
ParseFromIstream
(
&
file_stream
))
{
AERROR
<<
"Unable to load the model file: "
<<
model_file
<<
"."
;
return
;
}
ADEBUG
<<
"Succeeded in loading the model file: "
<<
model_file
<<
"."
;
}
double
MLPEvaluator
::
ComputeProbability
()
{
// TODO(kechxu) implement
return
0.0
;
CHECK
(
model_ptr_
.
get
()
!=
nullptr
);
double
probability
=
0.0
;
if
(
model_ptr_
->
dim_input
()
!=
static_cast
<
int
>
(
feature_values_
.
size
()))
{
AERROR
<<
"Model feature size not consistent with model proto definition."
;
return
probability
;
}
std
::
vector
<
double
>
layer_input
;
layer_input
.
reserve
(
model_ptr_
->
dim_input
());
std
::
vector
<
double
>
layer_output
;
// normalization
for
(
int
i
=
0
;
i
<
model_ptr_
->
dim_input
();
++
i
)
{
double
mean
=
model_ptr_
->
samples_mean
().
columns
(
i
);
double
std
=
model_ptr_
->
samples_std
().
columns
(
i
);
layer_input
.
push_back
(
apollo
::
prediction
::
util
::
Normalize
(
feature_values_
[
i
],
mean
,
std
));
}
for
(
int
i
=
0
;
i
<
model_ptr_
->
num_layer
();
++
i
)
{
if
(
i
>
0
)
{
layer_input
.
clear
();
layer_output
.
swap
(
layer_output
);
}
const
Layer
&
layer
=
model_ptr_
->
layer
(
i
);
for
(
int
col
=
0
;
col
<
layer
.
layer_output_dim
();
++
col
)
{
double
neuron_output
=
layer
.
layer_bias
().
columns
(
col
);
for
(
int
row
=
0
;
row
<
layer
.
layer_input_dim
();
++
row
)
{
double
weight
=
layer
.
layer_input_weight
().
rows
(
row
).
columns
(
col
);
neuron_output
+=
(
layer_input
[
row
]
*
weight
);
}
if
(
layer
.
layer_activation_type
()
==
"relu"
)
{
neuron_output
=
apollo
::
prediction
::
util
::
Relu
(
neuron_output
);
}
else
if
(
layer
.
layer_activation_type
()
==
"sigmoid"
)
{
neuron_output
=
apollo
::
prediction
::
util
::
Sigmoid
(
neuron_output
);
}
else
if
(
layer
.
layer_activation_type
()
==
"tanh"
)
{
neuron_output
=
std
::
tanh
(
neuron_output
);
}
else
{
LOG
(
ERROR
)
<<
"Undefined activation func: "
<<
layer
.
layer_activation_type
()
<<
", and default sigmoid will be used instead."
;
neuron_output
=
apollo
::
prediction
::
util
::
Sigmoid
(
neuron_output
);
}
layer_output
.
push_back
(
neuron_output
);
}
}
if
(
layer_output
.
size
()
!=
1
)
{
AERROR
<<
"Model output layer has incorrect # outputs: "
<<
layer_output
.
size
();
}
else
{
probability
=
layer_output
[
0
];
}
return
probability
;
}
}
// namespace prediction
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录