提交 2ef31a63 编写于 作者: C Crysple 提交者: Qi Luo

Audio: add direction transformation code, refactor & fix bugs

上级 17f95252
......@@ -15,12 +15,13 @@
*****************************************************************************/
#include "modules/audio/audio_component.h"
#include "modules/audio/inference/direction_detection.h"
#include "modules/audio/proto/audio_conf.pb.h"
#include "modules/common/proto/geometry.pb.h"
namespace apollo {
namespace audio {
using apollo::common::Point3D;
using apollo::drivers::microphone::config::AudioData;
AudioComponent::~AudioComponent() {}
......@@ -42,21 +43,24 @@ bool AudioComponent::Init() {
audio_conf.topic_conf().localization_topic_name(), nullptr);
audio_writer_ = node_->CreateWriter<AudioDetection>(
audio_conf.topic_conf().audio_detection_topic_name());
respeaker_extrinsics_file_ = audio_conf.respeaker_extrinsics_path();
return true;
}
bool AudioComponent::Proc(const std::shared_ptr<AudioData>& audio_data) {
audio_info_.Insert(audio_data);
AINFO << "Current direction is: "
<< get_direction(
audio_info_.GetSignals(audio_data->microphone_config().chunk()),
audio_data->microphone_config().sample_rate(),
audio_data->microphone_config().mic_distance());
// TODO(all) remove GetSignals() multiple calls
audio_info_.Insert(audio_data);
AudioDetection audio_detection;
*audio_detection.mutable_position() =
direction_detection_.EstimateSoundSource(
audio_info_.GetSignals(audio_data->microphone_config().chunk()),
respeaker_extrinsics_file_,
audio_data->microphone_config().sample_rate(),
audio_data->microphone_config().mic_distance());
auto signals =
audio_info_.GetSignals(audio_data->microphone_config().chunk());
MovingResult moving_result = moving_detection_.Detect(signals);
AudioDetection audio_detection;
audio_detection.set_moving_result(moving_result);
// TODO(all) add header to audio_detection
audio_writer_->Write(audio_detection);
......
......@@ -25,6 +25,7 @@
#include "cyber/component/component.h"
#include "modules/audio/common/audio_info.h"
#include "modules/audio/inference/direction_detection.h"
#include "modules/audio/inference/moving_detection.h"
#include "modules/audio/proto/audio.pb.h"
#include "modules/drivers/microphone/proto/audio.pb.h"
......@@ -59,6 +60,8 @@ class AudioComponent
AudioInfo audio_info_;
MovingDetection moving_detection_;
DirectionDetection direction_detection_;
std::string respeaker_extrinsics_file_;
};
CYBER_REGISTER_COMPONENT(AudioComponent)
......
......@@ -3,3 +3,4 @@ topic_conf {
audio_detection_topic_name: "/apollo/audio_detection"
localization_topic_name: "/apollo/localization/pose"
}
respeaker_extrinsics_path: "/apollo/modules/audio/conf/respeaker_extrinsics.yaml"
\ No newline at end of file
header:
stamp:
secs: 0
nsecs: 0
seq: 0
frame_id: novatel
child_frame_id: microphone
transform:
rotation:
x: 0
y: 0
z: 0
w: 1
translation:
x: 0.0
y: 0.68
z: 0.72
......@@ -19,7 +19,10 @@ cc_library(
hdrs = ["direction_detection.h"],
deps = [
"//cyber",
"//modules/common/proto:geometry_cc_proto",
"//third_party:libtorch",
"@com_github_jbeder_yaml_cpp//:yaml-cpp",
"@eigen",
],
)
......
......@@ -15,6 +15,7 @@
*****************************************************************************/
#include "modules/audio/inference/direction_detection.h"
#include "yaml-cpp/yaml.h"
namespace apollo {
namespace audio {
......@@ -22,10 +23,37 @@ namespace audio {
using torch::indexing::None;
using torch::indexing::Slice;
int get_direction(std::vector<std::vector<double>>&& channels_vec,
const int sample_rate, const int mic_distance) {
DirectionDetection::DirectionDetection() {}
DirectionDetection::~DirectionDetection() {}
Point3D DirectionDetection::EstimateSoundSource(
std::vector<std::vector<double>>&& channels_vec,
const std::string& respeaker_extrinsic_file, const int sample_rate,
const double mic_distance) {
if (!respeaker2imu_ptr_.get()) {
respeaker2imu_ptr_.reset(new Eigen::Matrix4d);
LoadExtrinsics(respeaker_extrinsic_file, respeaker2imu_ptr_.get());
}
const double degree =
EstimateDirection(move(channels_vec), sample_rate, mic_distance);
Eigen::Vector4d source_position(kDistance * sin(degree),
kDistance * cos(degree), 0, 1);
source_position = (*respeaker2imu_ptr_) * source_position;
Point3D source_position_p3d;
source_position_p3d.set_x(source_position[0]);
source_position_p3d.set_y(source_position[1]);
source_position_p3d.set_z(source_position[2]);
return source_position_p3d;
}
double DirectionDetection::EstimateDirection(
std::vector<std::vector<double>>&& channels_vec, const int sample_rate,
const double mic_distance) {
std::vector<torch::Tensor> channels_ts;
auto options = torch::TensorOptions().dtype(torch::kFloat32);
auto options = torch::TensorOptions().dtype(torch::kFloat64);
int size = static_cast<int>(channels_vec[0].size());
for (auto& signal : channels_vec) {
channels_ts.push_back(torch::from_blob(signal.data(), {size}, options));
......@@ -33,10 +61,10 @@ int get_direction(std::vector<std::vector<double>>&& channels_vec,
double tau0, tau1;
int theta0, theta1;
const double max_tau = mic_distance / SOUND_SPEED;
tau0 = gcc_phat(channels_ts[0], channels_ts[2], sample_rate, max_tau, 1);
const double max_tau = mic_distance / kSoundSpeed;
tau0 = GccPhat(channels_ts[0], channels_ts[2], sample_rate, max_tau, 1);
theta0 = asin(tau0 / max_tau) * 180 / M_PI;
tau1 = gcc_phat(channels_ts[1], channels_ts[3], sample_rate, max_tau, 1);
tau1 = GccPhat(channels_ts[1], channels_ts[3], sample_rate, max_tau, 1);
theta1 = asin(tau1 / max_tau) * 180 / M_PI;
int best_guess = 0;
......@@ -48,32 +76,75 @@ int get_direction(std::vector<std::vector<double>>&& channels_vec,
}
best_guess = (-best_guess + 120) % 360;
return best_guess;
return static_cast<double>(best_guess) / 90 * M_PI;
}
/*
* This function computes the offset between the signal sig and the reference
* signal refsig using the Generalized Cross Correlation - Phase Transform
* (GCC-PHAT)method.
*/
double gcc_phat(const torch::Tensor& sig, const torch::Tensor& refsig, int fs,
double max_tau, int interp) {
bool DirectionDetection::LoadExtrinsics(const std::string& yaml_file,
Eigen::Matrix4d* respeaker_extrinsic) {
if (!apollo::cyber::common::PathExists(yaml_file)) {
AINFO << yaml_file << " does not exist!";
return false;
}
YAML::Node node = YAML::LoadFile(yaml_file);
double qw = 0.0;
double qx = 0.0;
double qy = 0.0;
double qz = 0.0;
double tx = 0.0;
double ty = 0.0;
double tz = 0.0;
try {
if (node.IsNull()) {
AINFO << "Load " << yaml_file << " failed! please check!";
return false;
}
qw = node["transform"]["rotation"]["w"].as<double>();
qx = node["transform"]["rotation"]["x"].as<double>();
qy = node["transform"]["rotation"]["y"].as<double>();
qz = node["transform"]["rotation"]["z"].as<double>();
tx = node["transform"]["translation"]["x"].as<double>();
ty = node["transform"]["translation"]["y"].as<double>();
tz = node["transform"]["translation"]["z"].as<double>();
} catch (YAML::Exception& e) {
AERROR << "load camera extrinsic file " << yaml_file
<< " with error, YAML exception:" << e.what();
return false;
}
respeaker_extrinsic->setConstant(0);
Eigen::Quaterniond q;
q.x() = qx;
q.y() = qy;
q.z() = qz;
q.w() = qw;
(*respeaker_extrinsic).block<3, 3>(0, 0) = q.normalized().toRotationMatrix();
(*respeaker_extrinsic)(0, 3) = tx;
(*respeaker_extrinsic)(1, 3) = ty;
(*respeaker_extrinsic)(2, 3) = tz;
(*respeaker_extrinsic)(3, 3) = 1;
return true;
}
double DirectionDetection::GccPhat(const torch::Tensor& sig,
const torch::Tensor& refsig, int fs,
double max_tau, int interp) {
const int n_sig = sig.size(0), n_refsig = refsig.size(0),
n = n_sig + n_refsig;
torch::Tensor psig = at::constant_pad_nd(sig, {0, n_refsig}, 0);
torch::Tensor prefsig = at::constant_pad_nd(refsig, {0, n_sig}, 0);
psig = at::rfft(psig, 1, false, true);
prefsig = at::rfft(prefsig, 1, false, true);
torch::Tensor r = psig * at::conj(prefsig);
torch::Tensor cc = at::irfft(r / at::abs(r), 1, false, true, {interp * n});
ConjugateTensor(&prefsig);
torch::Tensor r = ComplexMultiply(psig, prefsig);
torch::Tensor cc =
at::irfft(r / ComplexAbsolute(r), 1, false, true, {interp * n});
int max_shift = static_cast<int>(interp * n / 2);
if (max_tau != 0)
max_shift = std::min(static_cast<int>(interp * fs * max_tau), max_shift);
auto begin = cc.index({Slice(0, cc.size(0) - max_shift, None)});
auto end = cc.index({Slice(0, None, max_shift + 1)});
auto begin = cc.index({Slice(cc.size(0) - max_shift, None)});
auto end = cc.index({Slice(None, max_shift + 1)});
cc = at::cat({begin, end});
auto ttt = at::argmax(at::abs(cc), 0);
// find max cross correlation index
const int shift = at::argmax(at::abs(cc), 0).item<int>() - max_shift;
const double tau = shift / static_cast<double>(interp * fs);
......@@ -81,5 +152,25 @@ double gcc_phat(const torch::Tensor& sig, const torch::Tensor& refsig, int fs,
return tau;
}
void DirectionDetection::ConjugateTensor(torch::Tensor* tensor) {
tensor->index_put_({"...", 1}, -tensor->index({"...", 1}));
}
torch::Tensor DirectionDetection::ComplexMultiply(const torch::Tensor& a,
const torch::Tensor& b) {
torch::Tensor real = a.index({"...", 0}) * b.index({"...", 0}) -
a.index({"...", 1}) * b.index({"...", 1});
torch::Tensor imag = a.index({"...", 0}) * b.index({"...", 1}) +
a.index({"...", 1}) * b.index({"...", 0});
return at::cat({real.reshape({-1, 1}), imag.reshape({-1, 1})}, 1);
}
torch::Tensor DirectionDetection::ComplexAbsolute(const torch::Tensor& tensor) {
torch::Tensor res = tensor * tensor;
res = at::sqrt(res.sum(1)).reshape({-1, 1});
return res;
}
} // namespace audio
} // namespace apollo
......@@ -19,21 +19,60 @@
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "Eigen/Eigen"
// Eigen 3.3.7: #define ALIVE (0)
// fastrtps: enum ChangeKind_t { ALIVE, ... };
#if defined(ALIVE)
#undef ALIVE
#endif
#include "ATen/ATen.h"
#include "cyber/cyber.h"
#include "torch/torch.h"
#include "cyber/cyber.h"
#include "modules/common/proto/geometry.pb.h"
namespace apollo {
namespace audio {
constexpr double SOUND_SPEED = 343.2;
using apollo::common::Point3D;
class DirectionDetection {
public:
DirectionDetection();
~DirectionDetection();
// Estimates the position of the source of the sound
Point3D EstimateSoundSource(std::vector<std::vector<double>>&& channels_vec,
const std::string& respeaker_extrinsic_file,
const int sample_rate, const double mic_distance);
private:
const double kSoundSpeed = 343.2;
const int kDistance = 50;
std::unique_ptr<Eigen::Matrix4d> respeaker2imu_ptr_;
// Estimates the direction of the source of the sound
double EstimateDirection(std::vector<std::vector<double>>&& channels_vec,
const int sample_rate, const double mic_distance);
bool LoadExtrinsics(const std::string& yaml_file,
Eigen::Matrix4d* respeaker_extrinsic);
// Computes the offset between the signal sig and the reference signal refsig
// using the Generalized Cross Correlation - Phase Transform (GCC-PHAT)method.
double GccPhat(const torch::Tensor& sig, const torch::Tensor& refsig, int fs,
double max_tau, int interp);
double gcc_phat(const torch::Tensor& sig, const torch::Tensor& refsig, int fs,
double max_tau, int interp);
int get_direction(std::vector<std::vector<double>>&& channels_vec,
const int sample_rate, const int mic_distance);
// Libtorch does not support Complex type currently.
void ConjugateTensor(torch::Tensor* tensor);
torch::Tensor ComplexMultiply(const torch::Tensor& a, const torch::Tensor& b);
torch::Tensor ComplexAbsolute(const torch::Tensor& tensor);
};
} // namespace audio
} // namespace apollo
......@@ -15,12 +15,16 @@ cc_proto_library(
proto_library(
name = "audio_proto",
srcs = ["audio.proto"],
deps = [
"//modules/common/proto:geometry_proto",
],
)
py_proto_library(
name = "audio_py_pb2",
deps = [
":audio_proto",
"//modules/common/proto:geometry_py_pb2",
],
)
......
......@@ -18,6 +18,8 @@ syntax = "proto2";
package apollo.audio;
import "modules/common/proto/geometry.proto";
enum MovingResult {
UNKNOWN = 0;
APPROACHING = 1;
......@@ -28,4 +30,5 @@ enum MovingResult {
message AudioDetection {
optional bool is_siren = 1;
optional MovingResult moving_result = 2 [default = UNKNOWN];
optional apollo.common.Point3D position = 3;
}
......@@ -26,4 +26,5 @@ message TopicConf {
message AudioConf {
optional TopicConf topic_conf = 1;
optional string respeaker_extrinsics_path = 2;
}
......@@ -42,6 +42,10 @@ data_files {
source_path: "radar_params"
dest_path: "/apollo/modules/perception/data/params"
}
data_files {
source_path: "microphone_params"
dest_path: "/apollo/modules/audio/conf/respeaker_extrinsics.yaml"
}
data_files {
source_path: "gnss_params/ant_imu_leverarm.yaml"
dest_path: "/apollo/modules/localization/msf/params/gnss_params/ant_imu_leverarm.yaml"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册