提交 42d5e5e2 编写于 作者: K kechxu 提交者: Calvin Miao

Prediction: implement records_to_data_for_learning

上级 c6012dba
......@@ -16,7 +16,6 @@
#include "modules/prediction/common/feature_output.h"
#include <string>
#include <vector>
#include "modules/common/util/file.h"
......@@ -59,7 +58,7 @@ void FeatureOutput::Insert(const Feature& feature) {
void FeatureOutput::InsertDataForLearning(
const Feature& feature, const std::vector<double>& feature_values,
const std::vector<double>& labels) {
const std::string& category) {
DataForLearning* data_for_learning =
list_data_for_learning_.add_data_for_learning();
data_for_learning->set_id(feature.id());
......@@ -67,9 +66,7 @@ void FeatureOutput::InsertDataForLearning(
for (size_t i = 0; i < feature_values.size(); ++i) {
data_for_learning->add_features_for_learning(feature_values[i]);
}
for (size_t i = 0; i < labels.size(); ++i) {
data_for_learning->add_labels(labels[i]);
}
data_for_learning->set_category(category);
}
void FeatureOutput::Write() {
......
......@@ -17,6 +17,7 @@
#pragma once
#include <vector>
#include <string>
#include "modules/prediction/proto/offline_features.pb.h"
......@@ -58,7 +59,7 @@ class FeatureOutput {
*/
static void InsertDataForLearning(
const Feature& feature, const std::vector<double>& feature_values,
const std::vector<double>& labels);
const std::string& category);
/**
* @brief Write features to a file
......
......@@ -74,7 +74,7 @@ void LaneScanningEvaluator::Evaluate(
std::vector<double> labels = {0.0};
if (FLAGS_prediction_offline_dataforlearning) {
FeatureOutput::InsertDataForLearning(*latest_feature_ptr, feature_values,
labels);
"cruise");
ADEBUG << "Save extracted features for learning locally.";
} else {
// TODO(jiacheng): once the model is trained, implement this online part.
......
......@@ -29,4 +29,12 @@ cc_binary(
],
)
cc_binary(
name = "records_to_data_for_learning",
srcs = ["records_to_data_for_learning.cc"],
deps = [
"//modules/prediction/common:message_process",
],
)
cpplint()
/******************************************************************************
* Copyright 2019 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/
#include "cyber/common/file.h"
#include "modules/common/util/string_util.h"
#include "modules/prediction/common/feature_output.h"
#include "modules/prediction/common/message_process.h"
#include "modules/prediction/common/prediction_system_gflags.h"
#include "modules/prediction/util/data_extraction.h"
namespace apollo {
namespace prediction {
void GenerateDataForLearning() {
if (!FeatureOutput::Ready()) {
AERROR << "Feature output is not ready.";
return;
}
if (FLAGS_prediction_offline_bags.empty()) {
return;
}
std::vector<std::string> inputs;
common::util::Split(FLAGS_prediction_offline_bags, ':', &inputs);
for (const auto& input : inputs) {
std::vector<std::string> offline_bags;
GetRecordFileNames(boost::filesystem::path(input), &offline_bags);
std::sort(offline_bags.begin(), offline_bags.end());
AINFO << "For input " << input << ", found " << offline_bags.size()
<< " rosbags to process";
for (std::size_t i = 0; i < offline_bags.size(); ++i) {
AINFO << "\tProcessing: [ " << i << " / " << offline_bags.size()
<< " ]: " << offline_bags[i];
MessageProcess::ProcessOfflineData(offline_bags[i]);
}
}
FeatureOutput::Close();
}
} // namespace prediction
} // namespace apollo
int main(int argc, char *argv[]) {
google::ParseCommandLineFlags(&argc, &argv, true);
apollo::prediction::GenerateDataForLearning();
return 0;
}
......@@ -3,7 +3,6 @@ syntax = "proto2";
package apollo.prediction;
import "modules/common/proto/geometry.proto";
import "modules/common/proto/pnc_point.proto";
import "modules/perception/proto/perception_obstacle.proto";
import "modules/prediction/proto/lane_graph.proto";
import "modules/prediction/proto/prediction_point.proto";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册