Unverified commit a065a242, authored by whs, committed by GitHub

[2.0 API] Enhance affine grid operator (#26385)

* Enhance affine grid operator:
1. Add CUDA kernel
2. Add align_corners option
test=develop

* Move new affine_grid api to functional
test=develop

* Add CUDA kernel for affine_grid.
test=develop

* Add more unit tests for grid sample API
test=develop
Parent 6f69fbc8
......@@ -28,10 +28,15 @@ using Tensor = framework::Tensor;
template <typename T>
struct Linspace<paddle::platform::CPUDeviceContext, T> {
-  void operator()(T start, T end, int count, framework::Tensor* numbers,
+  void operator()(T start, T end, int count, bool align_corners,
+                  framework::Tensor* numbers,
const framework::ExecutionContext& ctx) {
T* number_data = numbers->mutable_data<T>({count}, platform::CPUPlace());
T slice = (end - start) / (T)(count - 1);
if (!align_corners) {
slice = (end - start) / (T)count;
start *= (T)(count - 1) / (T)count;
}
for (int i = 0; i < count; ++i) {
number_data[i] = start + (T)i * slice;
}
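For reference, here is a minimal NumPy sketch (an illustration, not the operator itself) of what this Linspace change computes: with align_corners=False the step spans a full cell width and the start point is pulled inward by a factor of (count - 1) / count, so samples land at cell centers instead of on the corner points.

```python
import numpy as np

def linspace(start, end, count, align_corners=True):
    # Mirrors the Linspace functor above (a sketch, assuming the
    # symmetric [-1, 1] range the operator actually uses).
    if align_corners:
        step = (end - start) / (count - 1)
    else:
        step = (end - start) / count
        start = start * (count - 1) / count
    return start + step * np.arange(count)

print(linspace(-1.0, 1.0, 5))                       # [-1.  -0.5  0.   0.5  1. ]
print(linspace(-1.0, 1.0, 5, align_corners=False))  # [-0.8 -0.4  0.   0.4  0.8]
```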
......@@ -130,6 +135,10 @@ class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(true);
AddAttr<bool>("align_corners",
              "(bool, default true) Whether to align the corners of the input "
              "and output.")
    .SetDefault(true);
AddAttr<std::vector<int>>(
"output_shape",
"The target output image shape with format [N, C, H, W].")
......@@ -164,10 +173,12 @@ class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]]]
-  C[0] is the coordinates in height axis and C[1] is the coordinates in width axis.
+  C[0] is the coordinates in the height axis and C[1] is the coordinates in
+  the width axis.
Step2:
-  Tanspose and reshape C to shape [H * W, 2] and append ones to last dimension. The we get:
+  Transpose and reshape C to shape [H * W, 2] and append ones to the last
+  dimension. Then we get:
C_ = [[-1. -1. 1. ]
[-0.5 -1. 1. ]
[ 0. -1. 1. ]
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/affine_grid_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
}
template <typename T>
struct Linspace<paddle::platform::CUDADeviceContext, T> {
void operator()(T start, T end, int count, bool align_corners,
framework::Tensor* numbers,
const framework::ExecutionContext& ctx) {
T* number_data = numbers->mutable_data<T>({count}, ctx.GetPlace());
T slice = (end - start) / (T)(count - 1);
if (!align_corners) {
slice = (end - start) / (T)count;
start *= (T)(count - 1) / (T)count;
}
auto stream = ctx.cuda_device_context().stream();
int block = 512;
int grid = (count + block - 1) / block;
LinspaceKernel<T><<<grid, block, 0, stream>>>(start, slice, count,
number_data);
}
};
template <typename T>
__global__ void affine_grid_kernel(const int count, int n, int out_h, int out_w,
T h_start, T w_start, T h_step, T w_step,
const T* theta, // N, 2, 3
T* output) {
CUDA_KERNEL_LOOP(index, count) {
int w = index % out_w;
int h = (index / out_w) % out_h;
int n = index / (out_w * out_h);
T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
int theta_offset = n * 6; // 2 * 3;
// affine from (h_coor, w_coor) to (x, y)
output[index * 2] = theta[theta_offset] * h_coor +
theta[theta_offset + 1] * w_coor +
theta[theta_offset + 2];
output[index * 2 + 1] = theta[theta_offset + 3] * h_coor +
theta[theta_offset + 4] * w_coor +
theta[theta_offset + 5];
}
}
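Each CUDA thread above maps one output location: it forms the normalized coordinate (h_coor, w_coor), appends a homogeneous 1, and multiplies by the per-sample 2x3 theta. A hedged NumPy sketch of that single-point computation (theta values borrowed from the docstring example further down):

```python
import numpy as np

theta = np.array([[-0.7, -0.4, 0.3],
                  [ 0.6,  0.5, 1.5]], dtype=np.float32)  # one 2x3 affine matrix

h_coor, w_coor = -1.0, -1.0  # normalized coordinate of one grid point
point = theta @ np.array([h_coor, w_coor, 1.0], dtype=np.float32)
print(point)  # the pair written to output[index * 2] and output[index * 2 + 1]
```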
template <typename T>
__global__ void affine_grid_grad_kernel(const int count, int n, int out_h,
int out_w, T h_start, T w_start,
T h_step, T w_step,
const T* out_grad, // N, H, W, 2
T* theta_grad) { // N, 2, 3
CUDA_KERNEL_LOOP(index, count) {
int w = index % out_w;
int h = (index / out_w) % out_h;
int n = index / (out_w * out_h);
T h_coor = h_step * static_cast<T>(h) + static_cast<T>(h_start);
T w_coor = w_step * static_cast<T>(w) + static_cast<T>(w_start);
int theta_offset = n * 6; // 2 * 3;
T out_grad_x = out_grad[index * 2];
atomicAdd(theta_grad + theta_offset, out_grad_x * h_coor);
atomicAdd(theta_grad + theta_offset + 1, out_grad_x * w_coor);
atomicAdd(theta_grad + theta_offset + 2, out_grad_x);
T out_grad_y = out_grad[index * 2 + 1];
atomicAdd(theta_grad + theta_offset + 3, out_grad_y * h_coor);
atomicAdd(theta_grad + theta_offset + 4, out_grad_y * w_coor);
atomicAdd(theta_grad + theta_offset + 5, out_grad_y);
}
}
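The backward kernel is the transpose of that map: each thread scatters its incoming gradient into the six theta entries with atomicAdd, which amounts to summing outer products over all grid points. A sketch of the accumulation, assuming `points` holds hypothetical ((h_coor, w_coor), out_grad) pairs:

```python
import numpy as np

def theta_grad_ref(points):
    # points: iterable of ((h_coor, w_coor), out_grad) with out_grad shape [2].
    g = np.zeros((2, 3))
    for (h_coor, w_coor), out_grad in points:
        # The same six accumulations the atomicAdds above perform per thread.
        g += np.outer(out_grad, [h_coor, w_coor, 1.0])
    return g
```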
template <typename T>
class AffineGridOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* theta = ctx.Input<Tensor>("Theta");
int n = theta->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
Tensor h_sizes;
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
const int* h_size_data = h_sizes.data<int>();
h = h_size_data[2];
w = h_size_data[3];
} else {
h = size_attr[2];
w = size_attr[3];
}
auto* output = ctx.Output<Tensor>("Output");
T* out_data = output->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
T h_step;
T w_step;
T h_start = -1;
T w_start = -1;
if (align_corners) {
h_step = static_cast<T>(2) / static_cast<T>(h - 1);
w_step = static_cast<T>(2) / static_cast<T>(w - 1);
} else {
h_step = static_cast<T>(2) / static_cast<T>(h);
w_step = static_cast<T>(2) / static_cast<T>(w);
h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
}
const int count = n * h * w;
int block = 512;
int grid = (count + block - 1) / block;
auto cu_stream = ctx.cuda_device_context().stream();
affine_grid_kernel<<<grid, block, 0, cu_stream>>>(
count, n, h, w, h_start, w_start, h_step, w_step,
theta->data<T>(), // N, 2, 3
out_data);
}
};
template <typename T>
class AffineGridGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
int n = output_grad->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
auto* output_shape = ctx.Input<Tensor>("OutputShape");
Tensor h_sizes;
framework::TensorCopy(*output_shape, platform::CPUPlace(), &h_sizes);
const int* h_size_data = h_sizes.data<int>();
h = h_size_data[2];
w = h_size_data[3];
} else {
h = size_attr[2];
w = size_attr[3];
}
T* theta_grad_data = theta_grad->mutable_data<T>({n, 2, 3}, ctx.GetPlace());
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.cuda_device_context(), theta_grad, static_cast<T>(0));
T h_step;
T w_step;
T h_start = -1;
T w_start = -1;
if (align_corners) {
h_step = static_cast<T>(2) / static_cast<T>(h - 1);
w_step = static_cast<T>(2) / static_cast<T>(w - 1);
} else {
h_step = static_cast<T>(2) / static_cast<T>(h);
w_step = static_cast<T>(2) / static_cast<T>(w);
h_start *= static_cast<T>(h - 1) / static_cast<T>(h);
w_start *= static_cast<T>(w - 1) / static_cast<T>(w);
}
const int count = n * h * w;
VLOG(3) << "count: " << count << "; h_step: " << h_step
<< "; w_step: " << w_step << "; h_start: " << h_start
<< "; w_start: " << w_start;
int block = 512;
int grid = (count + block - 1) / block;
auto cu_stream = ctx.cuda_device_context().stream();
affine_grid_grad_kernel<<<grid, block, 0, cu_stream>>>(
count, n, h, w, h_start, w_start, h_step, w_step,
output_grad->data<T>(), theta_grad_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(affine_grid, ops::AffineGridOpCUDAKernel<float>,
ops::AffineGridOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(affine_grid_grad,
ops::AffineGridGradOpCUDAKernel<float>,
ops::AffineGridGradOpCUDAKernel<double>);
......@@ -37,12 +37,13 @@ using Array4 = Eigen::DSizes<int64_t, 4>;
*/
template <typename DeviceContext, typename T>
struct Linspace {
-  void operator()(T start, T end, int count, framework::Tensor* numbers,
+  void operator()(T start, T end, int count, bool align_corners,
+                  framework::Tensor* numbers,
const framework::ExecutionContext& ctx);
};
template <typename DeviceContext, typename T>
-inline void GetIdxMap(int n, int h, int w, Tensor* grid,
+inline void GetIdxMap(int n, int h, int w, bool align_corners, Tensor* grid,
const framework::ExecutionContext& ctx) {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
grid->mutable_data<T>({n, h, w, 3}, ctx.GetPlace());
......@@ -50,16 +51,19 @@ inline void GetIdxMap(int n, int h, int w, Tensor* grid,
// Get indexes of height with shape [height, width, 1]
Tensor h_idx;
Linspace<DeviceContext, T> linspace;
-  linspace((T)-1, (T)1, h, &h_idx, ctx);
+  linspace((T)-1, (T)1, h, align_corners, &h_idx, ctx);
auto h_idx_t = EigenTensor<T, 1>::From(h_idx);
// Get indexes of width with shape [height, width, 1]
Tensor w_idx;
-  linspace((T)-1, (T)1, w, &w_idx, ctx);
+  linspace((T)-1, (T)1, w, align_corners, &w_idx, ctx);
auto w_idx_t = EigenTensor<T, 1>::From(w_idx);
// Get constant ones tensor with shape [height, width, 1]
Tensor ones;
ones.mutable_data<T>({h, w, 1}, ctx.GetPlace());
-  auto ones_t = EigenTensor<T, 3>::From(ones).setConstant((T)1);
+  math::SetConstant<DeviceContext, T>()(
+      ctx.template device_context<DeviceContext>(), &ones, static_cast<T>(1));
+  auto ones_t = EigenTensor<T, 3>::From(ones);
// Get grid tensor with shape [n, h, w, 3] by concatenating h_idx, w_idx and
// ones
Tensor w_idx_map;
......@@ -74,11 +78,9 @@ inline void GetIdxMap(int n, int h, int w, Tensor* grid,
Tensor w_h_one_idx_map;
w_h_one_idx_map.mutable_data<T>({h, w, 3}, ctx.GetPlace());
auto w_h_one_idx_map_t = EigenTensor<T, 3>::From(w_h_one_idx_map);
w_idx_map_t.device(place) = w_idx_t.reshape(Array2(1, w))
.broadcast(Array2(h, 1))
.reshape(Array3(h, w, 1));
h_idx_map_t.device(place) = h_idx_t.reshape(Array2(1, h))
.broadcast(Array2(w, 1))
.shuffle(Array2(1, 0))
......@@ -97,6 +99,7 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
auto* theta = ctx.Input<Tensor>("Theta");
int n = theta->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
......@@ -116,7 +119,7 @@ class AffineGridOpKernel : public framework::OpKernel<T> {
ctx.template device_context<DeviceContext>(), output,
static_cast<T>(0));
Tensor grid;
-    GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
+    GetIdxMap<DeviceContext, T>(n, h, w, align_corners, &grid, ctx);
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = math::GetBlas<DeviceContext, T>(ctx);
......@@ -140,6 +143,7 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
auto theta_grad = ctx.Output<Tensor>(framework::GradVarName("Theta"));
int n = output_grad->dims()[0];
auto size_attr = ctx.Attr<std::vector<int>>("output_shape");
auto align_corners = ctx.Attr<bool>("align_corners");
int h = 0;
int w = 0;
if (size_attr.size() == 0) {
......@@ -158,7 +162,7 @@ class AffineGridGradOpKernel : public framework::OpKernel<T> {
ctx.template device_context<DeviceContext>(), theta_grad,
static_cast<T>(0));
Tensor grid;
-    GetIdxMap<DeviceContext, T>(n, h, w, &grid, ctx);
+    GetIdxMap<DeviceContext, T>(n, h, w, align_corners, &grid, ctx);
// output = grid * theta.T
// TODO(wanghaoshuang): Refine batched matrix multiply
auto blas = math::GetBlas<DeviceContext, T>(ctx);
......
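GetIdxMap builds the [n, h, w, 3] homogeneous base grid by broadcasting the two linspace vectors across the plane and appending a channel of ones. A NumPy sketch of the construction (channel order [w, h, 1] assumed from the unit-test reference below; the Eigen concatenation itself is elided in this hunk):

```python
import numpy as np

def idx_map(h, w):
    h_lin = np.linspace(-1, 1, h)  # align_corners=True case for brevity
    w_lin = np.linspace(-1, 1, w)
    w_map = np.broadcast_to(w_lin[None, :], (h, w))  # varies along width
    h_map = np.broadcast_to(h_lin[:, None], (h, w))  # varies along height
    ones = np.ones((h, w))
    return np.stack([w_map, h_map, ones], axis=-1)   # [h, w, 3], tiled to [n, h, w, 3]
```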
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle import fluid, nn
import paddle.fluid.dygraph as dg
import paddle.nn.functional as F
import paddle.fluid.initializer as I
import unittest
class AffineGridTestCase(unittest.TestCase):
def __init__(self,
methodName='runTest',
theta_shape=(20, 2, 3),
output_shape=[20, 2, 5, 7],
align_corners=True,
dtype="float32",
invalid_theta=False,
variable_output_shape=False):
super(AffineGridTestCase, self).__init__(methodName)
self.theta_shape = theta_shape
self.output_shape = output_shape
self.align_corners = align_corners
self.dtype = dtype
self.invalid_theta = invalid_theta
self.variable_output_shape = variable_output_shape
def setUp(self):
self.theta = np.random.randn(*(self.theta_shape)).astype(self.dtype)
def fluid_layer(self, place):
# align_corners = True
main = fluid.Program()
start = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, start):
theta_var = fluid.data(
"input", self.theta_shape, dtype=self.dtype)
y_var = fluid.layers.affine_grid(theta_var, self.output_shape)
feed_dict = {"input": self.theta}
exe = fluid.Executor(place)
exe.run(start)
y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var])
return y_np
def functional(self, place):
main = fluid.Program()
start = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, start):
theta_var = fluid.data(
"input", self.theta_shape, dtype=self.dtype)
y_var = F.affine_grid(
theta_var,
self.output_shape,
align_corners=self.align_corners)
feed_dict = {"input": self.theta}
exe = fluid.Executor(place)
exe.run(start)
y_np, = exe.run(main, feed=feed_dict, fetch_list=[y_var])
return y_np
def paddle_dygraph_layer(self):
theta_var = dg.to_variable(
    self.theta) if not self.invalid_theta else "invalid"
output_shape = dg.to_variable(
    self.output_shape) if self.variable_output_shape else self.output_shape
y_var = F.affine_grid(
theta_var, output_shape, align_corners=self.align_corners)
y_np = y_var.numpy()
return y_np
def _test_equivalence(self, place):
result1 = self.fluid_layer(place)
result2 = self.functional(place)
with dg.guard(place):
result3 = self.paddle_dygraph_layer()
if self.align_corners:
np.testing.assert_array_almost_equal(result1, result2)
np.testing.assert_array_almost_equal(result2, result3)
def runTest(self):
place = fluid.CPUPlace()
self._test_equivalence(place)
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
self._test_equivalence(place)
class AffineGridErrorTestCase(AffineGridTestCase):
def runTest(self):
place = fluid.CPUPlace()
with dg.guard(place):
with self.assertRaises(ValueError):
self.paddle_dygraph_layer()
def add_cases(suite):
suite.addTest(AffineGridTestCase(methodName='runTest'))
suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=True))
suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=False))
suite.addTest(
AffineGridTestCase(
methodName='runTest', variable_output_shape=True))
suite.addTest(
AffineGridTestCase(
methodName='runTest',
theta_shape=(20, 2, 3),
output_shape=[20, 1, 7, 7],
align_corners=True))
def add_error_cases(suite):
suite.addTest(
AffineGridErrorTestCase(
methodName='runTest', output_shape="not_valid"))
suite.addTest(
AffineGridErrorTestCase(
methodName='runTest',
invalid_theta=True)) # to test theta not variable error checking
def load_tests(loader, standard_tests, pattern):
suite = unittest.TestSuite()
add_cases(suite)
add_error_cases(suite)
return suite
if __name__ == '__main__':
unittest.main()
......@@ -17,14 +17,20 @@ import numpy as np
from op_test import OpTest
-def AffineGrid(theta, size):
+def AffineGrid(theta, size, align_corners):
n = size[0]
w = size[3]
h = size[2]
h_factor = w_factor = 1
if not align_corners:
h_factor = (h - 1) / float(h)
w_factor = (w - 1) / float(w)
    h_idx = np.repeat(
-        np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis]
+        np.linspace(-1, 1, h)[np.newaxis, :], w,
+        axis=0).T[:, :, np.newaxis] * h_factor
    w_idx = np.repeat(
-        np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis]
+        np.linspace(-1, 1, w)[np.newaxis, :], h,
+        axis=0)[:, :, np.newaxis] * w_factor
grid = np.concatenate(
[w_idx, h_idx, np.ones([h, w, 1])], axis=2) # h * w * 3
grid = np.repeat(grid[np.newaxis, :], size[0], axis=0) # n * h * w *3
......@@ -45,12 +51,17 @@ class TestAffineGridOp(OpTest):
theta = np.random.randint(1, 3, self.theta_shape).astype("float32")
theta = np.ones(self.theta_shape).astype("float32")
self.inputs = {'Theta': theta}
self.attrs = {"use_cudnn": True}
self.attrs = {
"use_cudnn": self.use_cudnn,
"align_corners": self.align_corners
}
if self.dynamic_shape:
self.inputs['OutputShape'] = self.output_shape
else:
self.attrs['output_shape'] = self.output_shape
-        self.outputs = {'Output': AffineGrid(theta, self.output_shape)}
+        self.outputs = {
+            'Output': AffineGrid(theta, self.output_shape, self.align_corners)
+        }
def test_check_output(self):
self.check_output()
......@@ -62,6 +73,8 @@ class TestAffineGridOp(OpTest):
self.theta_shape = (17, 2, 3)
self.output_shape = np.array([17, 2, 5, 7]).astype("int32")
self.dynamic_shape = False
self.use_cudnn = False
self.align_corners = True
class TestAffineGridOpCase1(TestAffineGridOp):
......@@ -69,6 +82,35 @@ class TestAffineGridOpCase1(TestAffineGridOp):
self.theta_shape = (20, 2, 3)
self.output_shape = np.array([20, 2, 5, 7]).astype("int32")
self.dynamic_shape = True
self.use_cudnn = True
self.align_corners = True
class TestAffineGridOpCase2(TestAffineGridOp):
def initTestCase(self):
self.theta_shape = (20, 2, 3)
self.output_shape = np.array([20, 2, 5, 7]).astype("int32")
self.dynamic_shape = True
self.use_cudnn = False
self.align_corners = True
class TestAffineGridOpCase3(TestAffineGridOp):
def initTestCase(self):
self.theta_shape = (20, 2, 3)
self.output_shape = np.array([20, 2, 5, 7]).astype("int32")
self.dynamic_shape = True
self.use_cudnn = False
self.align_corners = False
class TestAffineGridOpCase4(TestAffineGridOp):
def initTestCase(self):
self.theta_shape = (25, 2, 3)
self.output_shape = np.array([25, 2, 5, 6]).astype("int32")
self.dynamic_shape = False
self.use_cudnn = False
self.align_corners = False
if __name__ == '__main__':
......
......@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ...fluid.data_feeder import check_variable_and_dtype
from ...device import get_cudnn_version
from ...fluid.framework import core, in_dygraph_mode, Variable
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import core, in_dygraph_mode
from ...fluid.data_feeder import check_variable_and_dtype
from ...fluid import dygraph_utils
import numpy as np
# TODO: define special functions used in computer vision tasks
from ...fluid.layers import affine_channel #DEFINE_ALIAS
from ...fluid.layers import affine_grid #DEFINE_ALIAS
from ...fluid.layers import anchor_generator #DEFINE_ALIAS
from ...fluid.layers import bipartite_match #DEFINE_ALIAS
from ...fluid.layers import box_clip #DEFINE_ALIAS
......@@ -93,10 +96,98 @@ __all__ = [
'yolov3_loss'
]
from ...fluid import core, dygraph_utils
from ...fluid.framework import Variable, in_dygraph_mode
from ...device import get_cudnn_version
import numpy as np
def affine_grid(theta, out_shape, align_corners=True, name=None):
"""
It generates a grid of (x,y) coordinates using the parameters of
the affine transformation that correspond to a set of points where
the input feature map should be sampled to produce the transformed
output feature map.
Args:
    theta (Tensor): A tensor with shape [N, 2, 3]. It contains a batch of affine transform parameters. The data type can be float32 or float64.
    out_shape (Tensor | list | tuple): The shape of the target output with format [batch_size, channel, height, width]. ``out_shape`` can be a Tensor, a list, or a tuple. The data type must be int32.
    align_corners (bool): Whether to align the corners of the target feature map and the source feature map. Default: True.
    name (str|None): The default value is None. Normally there is no need for users to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
    Tensor, a tensor with shape [batch_size, H, W, 2], where 'H' and 'W' are the height and width of the feature map in the affine transformation. The data type is the same as ``theta``.
Raises:
    ValueError: If the type of the arguments is not supported.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
# theta shape = [1, 2, 3]
theta = np.array([[[-0.7, -0.4, 0.3],
[ 0.6, 0.5, 1.5]]]).astype("float32")
theta_t = paddle.to_tensor(theta)
y_t = F.affine_grid(
theta_t,
[1, 2, 3, 3],
align_corners=False)
print(y_t.numpy())
#[[[[ 1.0333333 0.76666665]
# [ 0.76666665 1.0999999 ]
# [ 0.5 1.4333333 ]]
#
# [[ 0.5666667 1.1666666 ]
# [ 0.3 1.5 ]
# [ 0.03333333 1.8333334 ]]
#
# [[ 0.10000002 1.5666667 ]
# [-0.16666666 1.9000001 ]
# [-0.43333334 2.2333333 ]]]]
"""
helper = LayerHelper('affine_grid')
if not isinstance(theta, Variable):
raise ValueError("The theta should be a Tensor.")
check_variable_and_dtype(theta, 'theta', ['float32', 'float64'],
'affine_grid')
cudnn_version = get_cudnn_version()
if cudnn_version is not None and cudnn_version >= 6000 and align_corners:
use_cudnn = True
else:
use_cudnn = False
if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \
isinstance(out_shape, Variable)):
raise ValueError("The out_shape should be a list, tuple or Tensor.")
if in_dygraph_mode():
_out_shape = out_shape.numpy().tolist() if isinstance(
out_shape, Variable) else out_shape
return core.ops.affine_grid(theta, "output_shape", _out_shape,
"align_corners", align_corners, "use_cudnn",
use_cudnn)
out = helper.create_variable_for_type_inference(theta.dtype)
ipts = {'Theta': theta}
attrs = {"align_corners": align_corners, "use_cudnn": use_cudnn}
if isinstance(out_shape, Variable):
ipts['OutputShape'] = out_shape
check_variable_and_dtype(out_shape, 'out_shape', ['int32'],
'affine_grid')
else:
attrs['output_shape'] = out_shape
helper.append_op(
type='affine_grid',
inputs=ipts,
outputs={'Output': out},
attrs=None if len(attrs) == 0 else attrs)
return out
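The docstring above exercises the dygraph branch; the static-graph branch goes through LayerHelper instead. A usage sketch of that path, mirroring the `functional` method of the unit test earlier in this diff:

```python
import numpy as np
import paddle.fluid as fluid
import paddle.nn.functional as F

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    theta = fluid.data("theta", [1, 2, 3], dtype="float32")
    grid = F.affine_grid(theta, [1, 2, 3, 3], align_corners=False)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
grid_np, = exe.run(main,
                   feed={"theta": np.ones([1, 2, 3], dtype="float32")},
                   fetch_list=[grid])
print(grid_np.shape)  # (1, 3, 3, 2)
```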
def grid_sample(x,
......@@ -166,8 +257,10 @@ def grid_sample(x,
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
-    Returns: Tensor, The shape of output is [N, C, grid_H, grid_W] in which `grid_H` is the height of grid
-        and `grid_W` is the width of grid. The data type is same as input tensor.
+    Returns:
+        Tensor, the shape of the output is [N, C, grid_H, grid_W], in which `grid_H` is the height of the grid and `grid_W` is the width of the grid. The data type is the same as the input tensor.
Examples:
.. code-block:: python
import paddle
......