提交 8a47c090 编写于 作者: L Liangliang Zhang

Perception: added graph segmentor.

上级 2ace5663
......@@ -24,4 +24,30 @@ cc_test(
],
)
cc_library(
name = "graph_segmentor",
srcs = [
"graph_segmentor.cc",
],
hdrs = [
"graph_segmentor.h",
],
deps = [
"//framework:cybertron",
":disjoint_set",
],
)
cc_test(
name = "graph_segmentor_test",
size = "small",
srcs = [
"graph_segmentor_test.cc",
],
deps = [
":graph_segmentor",
"@gtest//:main",
],
)
cpplint()
/******************************************************************************
* Copyright 2018 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/
#include "modules/perception/common/graph/graph_segmentor.h"
#include <cfloat>
#include "cybertron/common/log.h"
namespace apollo {
namespace perception {
namespace common {
namespace {
float GetThreshold(const size_t sz, const float c) { return c / sz; }
}
void GraphSegmentor::Init(const float initial_threshold) {
initial_threshold_ = initial_threshold;
thresholds_.reserve(kMaxVerticesNum);
thresholds_table_.resize(kMaxThresholdsNum);
thresholds_table_[0] = FLT_MAX;
for (size_t i = 1; i < kMaxThresholdsNum; ++i) {
thresholds_table_[i] = GetThreshold(i, initial_threshold_);
}
}
void GraphSegmentor::SegmentGraph(const int num_vertices, const int num_edges,
Edge* edges, bool need_sort) {
if (edges == nullptr) {
AERROR << "Input Null Edges.";
return;
}
if (need_sort) {
std::sort(edges, edges + num_edges);
}
universe_.Reset(num_vertices);
thresholds_.assign(num_vertices, initial_threshold_);
for (int i = 0; i < num_edges; ++i) {
Edge& edge = edges[i];
int a = universe_.Find(edge.a);
int b = universe_.Find(edge.b);
if (a == b) {
continue;
}
if (edge.w <= thresholds_[a] && edge.w <= thresholds_[b]) {
universe_.Join(a, b);
a = universe_.Find(a);
int size_a = universe_.GetSize(a);
thresholds_[a] =
edge.w + (size_a < static_cast<int>(kMaxThresholdsNum)
? thresholds_table_[size_a]
: GetThreshold(size_a, initial_threshold_));
}
}
}
} // namespace common
} // namespace perception
} // namespace apollo
/******************************************************************************
* Copyright 2018 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/
#ifndef MODULES_PERCEPTION_COMMON_GRAPH_SEGMENT_GRAPH_H_
#define MODULES_PERCEPTION_COMMON_GRAPH_SEGMENT_GRAPH_H_
#include <algorithm>
#include <vector>
#include "modules/perception/common/graph/disjoint_set.h"
namespace apollo {
namespace perception {
namespace common {
// @brief: graph edge definition
struct Edge {
float w = 0.0;
int a = 0;
int b = 0;
// @brief: edge comparison
bool operator<(const Edge& other) const { return this->w < other.w; }
};
class GraphSegmentor {
public:
GraphSegmentor() = default;
~GraphSegmentor() = default;
// @brief: initialize thresholds
void Init(const float initial_threshold);
// @brief: segment a graph, generating a disjoint-set forest
// representing the segmentation.
// @params[IN] num_vertices: number of vertices in graph.
// @params[IN] num_edges: number of edges in graph.
// @params[IN] edges: array of Edges.
// @params[OUT] need_sort: whether input edges needs to be sorted
void SegmentGraph(const int num_vertices, const int num_edges, Edge* edges,
bool need_sort = true);
// @brief: return the disjoint-set forest as the segmentation result.
Universe* mutable_universe() { return &universe_; }
const Universe& universe() { return universe_; }
private:
static const size_t kMaxVerticesNum = 10000;
static const size_t kMaxThresholdsNum = 50000;
float initial_threshold_ = 0.f;
std::vector<float> thresholds_;
std::vector<float> thresholds_table_;
Universe universe_;
}; // class GraphSegmentor
} // namespace common
} // namespace perception
} // namespace apollo
#endif // MODULES_PERCEPTION_COMMON_GRAPH_SEGMENT_GRAPH_H_
/******************************************************************************
* Copyright 2018 The Apollo Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*****************************************************************************/
#include "modules/perception/common/graph/graph_segmentor.h"
#include <memory>
#include "gtest/gtest.h"
namespace apollo {
namespace perception {
namespace common {
class GraphSegmentorTest : public testing::Test {
protected:
void SetUp() {
edges_ = new Edge[10];
edges_[0].w = 6.f;
edges_[0].a = 1;
edges_[0].b = 2;
edges_[1].w = 1.f;
edges_[1].a = 1;
edges_[1].b = 3;
edges_[2].w = 5.f;
edges_[2].a = 1;
edges_[2].b = 4;
edges_[3].w = 5.f;
edges_[3].a = 3;
edges_[3].b = 2;
edges_[4].w = 5.f;
edges_[4].a = 3;
edges_[4].b = 4;
edges_[5].w = 3.f;
edges_[5].a = 5;
edges_[5].b = 2;
edges_[6].w = 6.f;
edges_[6].a = 3;
edges_[6].b = 5;
edges_[7].w = 4.f;
edges_[7].a = 3;
edges_[7].b = 0;
edges_[8].w = 2.f;
edges_[8].a = 4;
edges_[8].b = 0;
edges_[9].w = 6.f;
edges_[9].a = 5;
edges_[9].b = 0;
}
void TearDown() {
delete[] edges_;
edges_ = nullptr;
}
Edge* edges_;
const int num_edges_ = 10;
const int num_vertices_ = 6;
};
TEST_F(GraphSegmentorTest, test_edge_comparison) {
EXPECT_TRUE(edges_[1] < edges_[0]);
EXPECT_FALSE(edges_[3] < edges_[4]);
EXPECT_FALSE(edges_[6] < edges_[7]);
}
TEST_F(GraphSegmentorTest, test_segment_graph) {
{
GraphSegmentor segmentor;
segmentor.Init(5.0);
segmentor.SegmentGraph(num_vertices_, num_edges_, nullptr, false);
EXPECT_EQ(0, segmentor.universe().GetSetsNum());
segmentor.SegmentGraph(num_vertices_, num_edges_, edges_, true);
EXPECT_EQ(3, segmentor.universe().GetSetsNum());
}
{
GraphSegmentor segmentor;
segmentor.Init(6.0);
segmentor.SegmentGraph(num_vertices_, num_edges_, edges_);
EXPECT_EQ(1, segmentor.universe().GetSetsNum());
}
{
GraphSegmentor segmentor;
segmentor.Init(2.0);
segmentor.SegmentGraph(num_vertices_, num_edges_, edges_);
EXPECT_EQ(4, segmentor.universe().GetSetsNum());
}
}
} // namespace common
} // namespace perception
} // namespace apollo
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册