未验证 提交 89ccf50b 编写于 作者: D dyning 提交者: GitHub

Merge pull request #1042 from dyning/develop

add anno for tps
...@@ -70,6 +70,13 @@ class LocalizationNetwork(object): ...@@ -70,6 +70,13 @@ class LocalizationNetwork(object):
return initial_bias return initial_bias
def __call__(self, image): def __call__(self, image):
"""
Estimating parameters of geometric transformation
Args:
image: input
Return:
batch_C_prime: the matrix of the geometric transformation
"""
F = self.F F = self.F
loc_lr = self.loc_lr loc_lr = self.loc_lr
if self.model_name == "large": if self.model_name == "large":
...@@ -215,6 +222,14 @@ class GridGenerator(object): ...@@ -215,6 +222,14 @@ class GridGenerator(object):
return batch_C_ex_part_tensor return batch_C_ex_part_tensor
def __call__(self, batch_C_prime, I_r_size): def __call__(self, batch_C_prime, I_r_size):
"""
Generate the grid for the grid_sampler.
Args:
batch_C_prime: the matrix of the geometric transformation
I_r_size: the shape of the input image
Return:
batch_P_prime: the grid for the grid_sampler
"""
C = self.build_C() C = self.build_C()
P = self.build_P(I_r_size) P = self.build_P(I_r_size)
inv_delta_C = self.build_inv_delta_C(C).astype('float32') inv_delta_C = self.build_inv_delta_C(C).astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册