提交 26856658 编写于 作者: A Aaron Xiao 提交者: Jiangtao Hu

Tools: Add proto flattening util and refactor routing tools.

上级 92f18639
......@@ -19,28 +19,72 @@
import google.protobuf.text_format as text_format
def get_pb_from_text_file(filename, proto_obj):
def get_pb_from_text_file(filename, pb_value):
"""Get a proto from given text file."""
with open(filename, 'r') as file_in:
return text_format.Merge(file_in.read(), proto_obj)
return text_format.Merge(file_in.read(), pb_value)
def get_pb_from_bin_file(filename, proto_obj):
def get_pb_from_bin_file(filename, pb_value):
"""Get a proto from given binary file."""
with open(filename, 'rb') as file_in:
proto_obj.ParseFromString(file_in.read())
return proto_obj
pb_value.ParseFromString(file_in.read())
return pb_value
def get_pb_from_file(filename, proto_obj):
def get_pb_from_file(filename, pb_value):
"""Get a proto from given file by trying binary mode and text mode."""
try:
return get_pb_from_bin_file(filename, proto_obj)
return get_pb_from_bin_file(filename, pb_value)
except:
print 'Info: Cannot parse %s as binary proto.' % filename
try:
return get_pb_from_text_file(filename, proto_obj)
return get_pb_from_text_file(filename, pb_value)
except:
print 'Error: Cannot parse %s as text proto' % filename
return None
def flatten(pb_value, selectors):
"""
Get a flattened tuple from pb_value. Selectors is a list of sub-fields.
Usage:
For a pb_value of:
total_pb = {
me: { name: 'myself' }
children: [{ name: 'child0' }, { name: 'child1' }]
}
my_name, child0_name = flatten(total_pb, ['me.name', 'children[0].name'])
# You get (my_name='myself', child0_name='child0')
children_names = flatten(total_pb, 'children.name')
# You get (children_names=['child0', 'child1'])
"""
def __select_field(val, field):
if hasattr(val, '__len__'):
# Flatten repeated field.
return [__select_field(elem, field) for elem in val]
if not field.endswith(']'):
# Simple field.
return val.__getattribute__(field)
# field contains index: "field[index]".
field, index = field.split('[')
val = val.__getattribute__(field)
index = int(index[:-1])
return val[index] if index < len(val) else None
def __select(val, selector):
for field in selector.split('.'):
val = __select_field(val, field)
if val is None:
return None
return val
# Return the single result for single selector.
if type(selectors) is str:
return __select(pb_value, selectors)
# Return tuple result for multiple selectors.
return (__select(pb_value, selector) for selector in selectors)
......@@ -16,12 +16,16 @@
# limitations under the License.
###############################################################################
import sys
import itertools
import sys
import matplotlib.pyplot as plt
import common.proto_utils as proto_utils
import debug_topo
import gen.topo_graph_pb2 as topo_graph_pb2
import gen.router_pb2 as router_pb2
from modules.routing.proto.routing_pb2 import RoutingResponse
from modules.routing.proto.topo_graph_pb2 import Graph
color_iter = itertools.cycle(
['navy', 'c', 'cornflowerblue', 'gold', 'darkorange'])
......@@ -31,20 +35,10 @@ g_center_point_dict = {}
def get_center_of_passage_region(region):
"""Get center of passage region center curve"""
center_points = []
for seg in region.segment:
center_points.append(g_center_point_dict[seg.id])
center_points = [g_center_point_dict[seg.id] for seg in region.segment]
return center_points[len(center_points) // 2]
def read_routing_result(file_name):
"""Read routing result"""
fin = open(file_name)
result = router_pb2.RoutingResult()
result.ParseFromString(fin.read())
return result
def plot_region(region, color):
"Plot passage region"
for seg in region.segment:
......@@ -134,12 +128,9 @@ if __name__ == '__main__':
sys.exit(0)
print 'Please wait for loading data...'
file_name = sys.argv[1]
fin = open(file_name)
graph = topo_graph_pb2.Graph()
graph.ParseFromString(fin.read())
for nd in graph.node:
g_central_curve_dict[nd.lane_id] = nd.central_curve
topo_graph_file = sys.argv[1]
graph = proto_utils.get_pb_from_bin_file(topo_graph_file, Graph())
g_central_curve_dict = {nd.lane_id : nd.central_curve for nd in graph.node}
plt.ion()
while 1:
......@@ -151,7 +142,9 @@ if __name__ == '__main__':
if argv[0] == 'q':
sys.exit(0)
elif argv[0] == 'p':
result = read_routing_result(sys.argv[2])
routing_result_file = sys.argv[2]
result = proto_utils.get_pb_from_bin_file(routing_result_file,
RoutingResponse())
plot_result(result, g_central_curve_dict)
else:
print '[ERROR] wrong command'
......
......@@ -16,15 +16,17 @@
# limitations under the License.
###############################################################################
import sys
import itertools
import os
import sys
import gflags
import matplotlib.pyplot as plt
import debug_topo
import modules.routing.proto.topo_graph_pb2 as topo_graph_pb2
import util
import gflags
import os
color_iter = itertools.cycle(
['navy', 'c', 'cornflowerblue', 'gold', 'darkorange'])
......
......@@ -16,10 +16,13 @@
# limitations under the License.
###############################################################################
import sys
import math
import itertools
import math
import sys
import matplotlib.pyplot as plt
import common.proto_utils as proto_utils
import modules.routing.proto.topo_graph_pb2 as topo_graph_pb2
import util
......@@ -27,14 +30,6 @@ color_iter = itertools.cycle(
['navy', 'c', 'cornflowerblue', 'gold', 'darkorange'])
def downsample_array(array):
"""down sample given array"""
skip = 5
result = array[::skip]
result.append(array[-1])
return result
def calculate_s(px, py):
"""Calculate s array based on x and y arrays"""
dis = 0.0
......@@ -46,25 +41,10 @@ def calculate_s(px, py):
return ps
def extract_line(line):
"""extract line, return x array and y array"""
px = []
py = []
for pt in line.point:
px.append(float(pt.x))
py.append(float(pt.y))
return px, py
def draw_line(line, color):
"""draw line, return x array and y array"""
px = []
py = []
for pt in line.point:
px.append(float(pt.x))
py.append(float(pt.y))
px = downsample_array(px)
py = downsample_array(py)
px, py = proto_utils.flatten(line.point, ['x', 'y'])
px, py = util.downsample_array(px), util.downsample_array(py)
plt.gca().plot(px, py, color=color, lw=3, alpha=0.8)
return px, py
......@@ -106,11 +86,10 @@ def plot_central_curve_with_s_range(central_curve, start_s, end_s, color):
"""plot topology graph node with given start and end s, return middle point"""
node_x = []
node_y = []
plot_length = 0.0
for curve in central_curve.segment:
px, py = extract_line(curve.line_segment)
node_x = node_x + px
node_y = node_y + py
px, py = proto_utils.flatten(curve.line_segment.point, ['x', 'y'])
node_x.extend(px)
node_y.extend(py)
start_plot_index = 0
end_plot_index = len(node_x)
node_s = calculate_s(node_x, node_y)
......
......@@ -15,9 +15,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
"""Show road."""
import sys
import matplotlib.pyplot as plt
import common.proto_utils as proto_utils
import util
g_color = [
......@@ -31,13 +35,8 @@ def draw_line(line_segment, color):
:param line_segment:
:return: none
"""
px = []
py = []
for p in line_segment.point:
px.append(float(p.x))
py.append(float(p.y))
px = downsample_array(px)
py = downsample_array(py)
px, py = proto_utils.flatten(line_segment.point, ['x', 'y'])
px, py = downsample_array(px), downsample_array(py)
plt.gca().plot(px, py, lw=10, alpha=0.8, color=color)
return px[len(px) // 2], py[len(py) // 2]
......@@ -76,13 +75,8 @@ def draw_boundary(line_segment):
:param line_segment:
:return:
"""
px = []
py = []
for p in line_segment.point:
px.append(float(p.x))
py.append(float(p.y))
px = downsample_array(px)
py = downsample_array(py)
px, py = proto_utils.flatten(line_segment.point, ['x', 'y'])
px, py = downsample_array(px), downsample_array(py)
plt.gca().plot(px, py, 'k')
......@@ -121,15 +115,10 @@ def draw_map(drivemap):
for road in drivemap.road:
lanes = []
for sec in road.section:
for lane in sec.lane_id:
lanes.append(lane.id)
lanes.extend(proto_utils.flatten(sec.lane_id, 'id'))
road_lane_set.append(lanes)
for lane in drivemap.lane:
#print lane.type
#print lane.central_curve
#break
#print [f.name for f in lane.central_curve.DESCRIPTOR.fields]
for curve in lane.central_curve.segment:
if curve.HasField('line_segment'):
road_idx = get_road_index_of_lane(lane.id.id, road_lane_set)
......@@ -142,7 +131,6 @@ def draw_map(drivemap):
#break
#if curve.HasField('arc'):
# draw_arc(curve.arc)
#print "arc"
for curve in lane.left_boundary.curve.segment:
if curve.HasField('line_segment'):
......
......@@ -91,10 +91,9 @@ def onclick(event):
print 'cmd>',
def downsample_array(array):
def downsample_array(array, step=5):
"""down sample given array"""
skip = 5
result = array[::skip]
result = array[::step]
result.append(array[-1])
return result
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册