提交 50ab2780 编写于 作者: K kechxu 提交者: Jiangtao Hu

Prediction: implement load tensor 3d'

上级 9a1b5df0
......@@ -139,13 +139,13 @@ void Conv1d::Run(const std::vector<Eigen::MatrixXf>& inputs,
output->resize(output_num_row, output_num_col);
for (int i = 0; i < output_num_col; ++i) {
for (int j = 0; j + kernel_size < inputs[0].cols(); j += stride_) {
float output_i_j = 0.0;
float output_i_j_unbiased = 0.0;
for (int p = 0; p < inputs[0].rows(); ++p) {
for (int q = j; q < j + kernel_size; ++q) {
output_i_j += inputs[0](p, q) * kernel_[i](p, q - j);
output_i_j_unbiased += inputs[0](p, q) * kernel_[i](p, q - j);
}
}
(*output)(i, j) = output_i_j;
(*output)(i, j) = output_i_j_unbiased + bias_(i);
}
}
}
......
......@@ -17,6 +17,7 @@
#include "modules/prediction/network/net_util.h"
#include <unordered_map>
#include <cmath>
#include "cyber/common/log.h"
......@@ -24,9 +25,13 @@ namespace apollo {
namespace prediction {
namespace network {
float sigmoid(const float x) { return 1 / (1 + exp(-x)); }
float sigmoid(const float x) {
return static_cast<float>(1.0 / (1.0 + std::exp(-x)));
}
float tanh(const float x) { return std::tanh(x); }
float tanh(const float x) {
return static_cast<float>(std::tanh(x));
}
float linear(const float x) { return x; }
......@@ -95,7 +100,25 @@ bool LoadTensor(const TensorParameter& tensor_pb,
AERROR << "Fail to load the necessary fields!";
return false;
}
// TODO(kechxu) implement
int num_depth = tensor_pb.shape(0);
int num_row = tensor_pb.shape(1);
int num_col = tensor_pb.shape(2);
CHECK_EQ(tensor_pb.data_size(), num_depth * num_row * num_col);
tensor3d->clear();
tensor3d->resize(num_depth);
for (int k = 0; k < num_depth; ++k) {
tensor3d->operator[](k).resize(num_row, num_col);
}
int tensor_pb_index = 0;
for (int k = 0; k < num_depth; ++k) {
for (int i = 0; i < num_row; ++i) {
for (int j = 0; j < num_col; ++j) {
tensor3d->operator[](k)(i, j) = tensor_pb.data(tensor_pb_index);
++tensor_pb_index;
}
}
}
CHECK_EQ(tensor_pb_index, num_depth * num_row * num_col);
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册