data x and y is indexing the 3rd dimension (in height dimension),
data x and y is indexing the 3rd dimension (in height dimension),
finally results is the bilinear interpolation or nearest value of 4 nearest corner
finally results is the bilinear interpolation or nearest value of 4 nearest corner
points. The output tensor shape will be [N, C, H, W].
points. The output tensor shape will be [N, C, H, W].
.. code-block:: text
Step 1:
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
.. code-block:: text
.. code-block:: text
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points or nearest interpolate point value
interpolate point value by 4 nearest points or nearest interpolate point value
by nearest point.
by nearest point.
.. code-block:: text
wn ------- y_n ------- en
wn ------- y_n ------- en
| | |
| | |
| d_n |
| d_n |
...
@@ -224,6 +232,7 @@ def grid_sample(x,
...
@@ -224,6 +232,7 @@ def grid_sample(x,
| d_s |
| d_s |
| | |
| | |
ws ------- y_s ------- wn
ws ------- y_s ------- wn
For bilinear interpolation:
For bilinear interpolation:
x_w = floor(x) // west side x coord
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
x_e = x_w + 1 // east side x coord
...
@@ -237,8 +246,10 @@ def grid_sample(x,
...
@@ -237,8 +246,10 @@ def grid_sample(x,
en = X[:, :, y_n, x_e] // north-east point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-east point value
ws = X[:, :, y_s, x_w] // south-east point value
es = X[:, :, y_s, x_w] // north-east point value
es = X[:, :, y_s, x_w] // north-east point value
output = wn * d_e * d_s + en * d_w * d_s
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
+ ws * d_e * d_n + es * d_w * d_n
Args:
Args:
x(Tensor): The input tensor, which is a 4-d tensor with shape
x(Tensor): The input tensor, which is a 4-d tensor with shape
[N, C, H, W], N is the batch size, C is the channel
[N, C, H, W], N is the batch size, C is the channel
...
@@ -262,7 +273,9 @@ def grid_sample(x,
...
@@ -262,7 +273,9 @@ def grid_sample(x,
Tensor, The shape of output is [N, C, grid_H, grid_W] in which `grid_H` is the height of grid and `grid_W` is the width of grid. The data type is same as input tensor.
Tensor, The shape of output is [N, C, grid_H, grid_W] in which `grid_H` is the height of grid and `grid_W` is the width of grid. The data type is same as input tensor.