提交 3b35cb85 编写于 作者: K kechxu 提交者: Calvin Miao

implement computing probabilites for mlp model

上级 eb4fd44e
......@@ -12,4 +12,11 @@ cc_library(
],
)
cc_library(
name = "prediction_util",
srcs = ["prediction_util.cc"],
hdrs = ["prediction_util.h"],
deps = [],
)
cpplint()
/******************************************************************************
* Copyright 2017 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/
#include <cmath>
#include "modules/prediction/common/prediction_util.h"
namespace apollo {
namespace prediction {
namespace util {
double Normalize(const double value, const double mean, const double std) {
double eps = 1e-10;
return (value - mean) / (std + eps);
}
double Sigmoid(const double value) {
return 1 / (1 + std::exp(-1.0 * value));
}
double Relu(const double value) {
return (value > 0.0) ? value : 0.0;
}
} // namespace util
} // namespace prediction
} // namespace apollo
/******************************************************************************
* Copyright 2017 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/
#ifndef MODULES_PREDICTION_COMMON_PREDICTION_UTIL_H_
#define MODULES_PREDICTION_COMMON_PREDICTION_UTIL_H_
namespace apollo {
namespace prediction {
namespace util {
double Normalize(const double value, const double mean, const double std);
double Sigmoid(const double value);
double Relu(const double value);
} // namespace util
} // namespace prediction
} // namespace apollo
#endif // MODULES_PREDICTION_COMMON_PREDICTION_UTIL_H_
......@@ -17,6 +17,7 @@ cc_library(
"//modules/prediction/common:prediction_common",
"//modules/common/math:math_utils",
"//modules/prediction/proto:fnn_vehicle_model_proto",
"//modules/prediction/common:prediction_util",
],
)
......
......@@ -15,10 +15,12 @@
*****************************************************************************/
#include <cmath>
#include <fstream>
#include "modules/prediction/evaluator/vehicle/mlp_evaluator.h"
#include "modules/prediction/common/prediction_gflags.h"
#include "modules/common/math/math_utils.h"
#include "modules/prediction/common/prediction_util.h"
namespace apollo {
namespace prediction {
......@@ -254,12 +256,76 @@ void MLPEvaluator::SetLaneFeatureValues(Obstacle* obstacle_ptr,
}
void MLPEvaluator::LoadModel(const std::string& model_file) {
// TODO(kechxu) implement
model_ptr_.reset(new FnnVehicleModel());
CHECK(model_ptr_ != nullptr);
std::fstream file_stream(model_file, std::ios::in | std::ios::binary);
if (!file_stream.good()) {
AERROR << "Unable to open the model file: " << model_file << ".";
return;
}
if (!model_ptr_->ParseFromIstream(&file_stream)) {
AERROR << "Unable to load the model file: " << model_file << ".";
return;
}
ADEBUG << "Succeeded in loading the model file: " << model_file << ".";
}
double MLPEvaluator::ComputeProbability() {
// TODO(kechxu) implement
return 0.0;
CHECK(model_ptr_.get() != nullptr);
double probability = 0.0;
if (model_ptr_->dim_input() != static_cast<int>(feature_values_.size())) {
AERROR << "Model feature size not consistent with model proto definition.";
return probability;
}
std::vector<double> layer_input;
layer_input.reserve(model_ptr_->dim_input());
std::vector<double> layer_output;
// normalization
for (int i = 0; i < model_ptr_->dim_input(); ++i) {
double mean = model_ptr_->samples_mean().columns(i);
double std = model_ptr_->samples_std().columns(i);
layer_input.push_back(
apollo::prediction::util::Normalize(feature_values_[i], mean, std));
}
for (int i = 0; i < model_ptr_->num_layer(); ++i) {
if (i > 0) {
layer_input.clear();
layer_output.swap(layer_output);
}
const Layer& layer = model_ptr_->layer(i);
for (int col = 0; col < layer.layer_output_dim(); ++col) {
double neuron_output = layer.layer_bias().columns(col);
for (int row = 0; row < layer.layer_input_dim(); ++row) {
double weight = layer.layer_input_weight().rows(row).columns(col);
neuron_output += (layer_input[row] * weight);
}
if (layer.layer_activation_type() == "relu") {
neuron_output = apollo::prediction::util::Relu(neuron_output);
} else if (layer.layer_activation_type() == "sigmoid") {
neuron_output = apollo::prediction::util::Sigmoid(neuron_output);
} else if (layer.layer_activation_type() == "tanh") {
neuron_output = std::tanh(neuron_output);
} else {
LOG(ERROR) << "Undefined activation func: "
<< layer.layer_activation_type()
<< ", and default sigmoid will be used instead.";
neuron_output = apollo::prediction::util::Sigmoid(neuron_output);
}
layer_output.push_back(neuron_output);
}
}
if (layer_output.size() != 1) {
AERROR << "Model output layer has incorrect # outputs: "
<< layer_output.size();
} else {
probability = layer_output[0];
}
return probability;
}
} // namespace prediction
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册