提交 67f2e074 编写于 作者: Y Yajia Zhang 提交者: Jiangtao Hu

prediction: added bounding rectangle class in mlp_train

上级 e4e05b32
from vector2d import Vector2
from rotation2d import *
from util import segment_overlap
class BoundingRectangle:
def __init__(self, x, y, theta, length, width):
self.vertices = [None] * 4
dx = 0.5 * length
dy = 0.5 * width
cos_theta = cos(theta)
sin_theta = sin(theta)
self.vertices[0] = rotate_fast(Vector2(dx, -dy), cos_theta, sin_theta)
self.vertices[1] = rotate_fast(Vector2(dx, dy), cos_theta, sin_theta)
self.vertices[2] = rotate_fast(Vector2(-dx, dy), cos_theta, sin_theta)
self.vertices[3] = rotate_fast(Vector2(-dx, -dy), cos_theta, sin_theta)
for i in range(4):
self.vertices[i].x += x
self.vertices[i].y += y
def overlap(self, rect):
for i in range(4):
v0 = self.vertices[i]
v1 = self.vertices[(i + 1) % 4]
range_self = self.project(v0, v1)
range_other = rect.project(v0, v1)
if segment_overlap(range_self[0], range_self[1], range_other[0], range_other[1]) == False:
return False
for i in range(4):
v0 = rect.vertices[i]
v1 = rect.vertices[(i + 1) % 4]
range_self = self.project(v0, v1)
range_other = rect.project(v0, v1)
if segment_overlap(range_self[0], range_self[1], range_other[0], range_other[1]) == False:
return False
return True
def project(self, p0, p1):
v = p1.subtract(p0)
n = v.norm()
rmin = float("inf")
rmax = float("-inf")
for i in range(4):
t = self.vertices[i].subtract(p0)
r = t.dot(v) / n
if r < rmin:
rmin = r
if r > rmax:
rmax = r
return [rmin, rmax]
def print_vertices(self):
for i in range(4):
print(str(self.vertices[i].x) + "\t" + str(self.vertices[i].y) + "\n")
from math import cos, sin
from vector2d import Vector2
def rotate(v, theta):
cos_theta = cos(theta)
sin_theta = sin(theta)
return rotate_fast(v, cos_theta, sin_theta)
def rotate_fast(v, cos_theta, sin_theta):
x = cos_theta * v.x - sin_theta * v.y
y = sin_theta * v.x + cos_theta * v.y
return Vector2(x, y)
def segment_overlap(a, b, x, y):
if b < x or a > y:
return False
return True
def vector_projection_overlap(p0, p1, p2, p3):
v = p1.subtract(p0)
n_square = v.norm_square()
v0 = p2.subtract(p0)
v1 = p3.subtract(p0)
t0 = v0.dot(v)
t1 = v1.dot(v)
if t0 > t1:
t = t0
t0 = t1
t1 = t
return segment_overlap(t0, t1, 0.0, n_square)
from math import sqrt
class Vector2:
def __init__(self, x, y):
self.x = x
self.y = y
def add(self, v):
return Vector2(self.x + v.x, self.y + v.y)
def subtract(self, v):
return Vector2(self.x - v.x, self.y - v.y)
def dot(self, v):
return self.x * v.x + self.y * v.y
def norm(self):
return sqrt(self.x * self.x + self.y * self.y)
def norm_square(self):
return self.x * self.x + self.y * self.y
def print_point(self):
print(str(self.x) + "\t" + str(self.y) + "\n")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册