提交 ff6329bd 编写于 作者: D dengkaipeng

fix some inappropriate expressions in api doc for grid_sampler. test=develop

上级 593e1b18
......@@ -60,10 +60,10 @@ class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
DataLayout::kNCHW, framework::vectorize2int(output->dims()));
CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerForward(
handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc, input_data,
grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc, output_data));
handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc,
input_data, grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc,
output_data));
}
};
template <typename T>
......@@ -94,25 +94,29 @@ class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
const T* input_data = input->data<T>();
const T* grid_data = grid->data<T>();
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>(output_grad_dims, ctx.GetPlace());
T* grid_grad_data = grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
T* input_grad_data =
input_grad->mutable_data<T>(output_grad_dims, ctx.GetPlace());
T* grid_grad_data =
grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor input_grad_desc;
ScopedTensorDescriptor output_grad_desc;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_input_grad_desc = input_grad_desc.descriptor<T>(
cudnnTensorDescriptor_t cudnn_input_grad_desc =
input_grad_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(input_grad->dims()));
cudnnTensorDescriptor_t cudnn_output_grad_desc = output_grad_desc.descriptor<T>(
cudnnTensorDescriptor_t cudnn_output_grad_desc =
output_grad_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(output_grad->dims()));
CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerBackward(
handle, cudnn_st_dest, CudnnDataType<T>::kOne(),
cudnn_input_desc, input_data, CudnnDataType<T>::kZero(),
cudnn_input_grad_desc, input_grad_data, CudnnDataType<T>::kOne(),
cudnn_output_grad_desc, output_grad_data, grid_data,
CudnnDataType<T>::kZero(), grid_grad_data));
handle, cudnn_st_dest, CudnnDataType<T>::kOne(), cudnn_input_desc,
input_data, CudnnDataType<T>::kZero(), cudnn_input_grad_desc,
input_grad_data, CudnnDataType<T>::kOne(), cudnn_output_grad_desc,
output_grad_data, grid_data, CudnnDataType<T>::kZero(),
grid_grad_data));
}
};
......
......@@ -36,12 +36,19 @@ class GridSampleOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
PADDLE_ENFORCE(x_dims.size() == 4, "Input(X) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims.size() == 4, "Input(Grid) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(x_dims.size() == 4,
"Input(X) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims.size() == 4,
"Input(Grid) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0], "Input(X) and Input(Grid) dims[0] should be equal.");
PADDLE_ENFORCE_EQ(grid_dims[1], x_dims[2], "Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
PADDLE_ENFORCE_EQ(grid_dims[2], x_dims[3], "Input(X) dims[3] and Input(Grid) dims[2] should be equal.");
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
"Input(X) and Input(Grid) dims[0] should be equal.");
PADDLE_ENFORCE_EQ(
grid_dims[1], x_dims[2],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
PADDLE_ENFORCE_EQ(
grid_dims[2], x_dims[3],
"Input(X) dims[3] and Input(Grid) dims[2] should be equal.");
ctx->SetOutputDim("Output", x_dims);
ctx->ShareLoD("X", "Output");
......@@ -57,16 +64,15 @@ class GridSampleOp : public framework::OperatorWithKernel {
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_);
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
AddInput("X",
"(Tensor) The input data of GridSampleOp, "
"This is a 4-D tensor with shape of [N, C, H, W]");
AddInput(
......@@ -74,20 +80,20 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor) The input grid of GridSampleOp generated by AffineGridOp, "
"This is a 4-D tensor with shape of [N, H, W, 2] is the concatenation "
"of x and y coordinates with shape [N, H, W] in last dimention");
AddOutput(
"Output",
"(Tensor) Output tensor with shape [N, C, H, W]");
AddOutput("Output", "(Tensor) Output tensor with shape [N, C, H, W]");
AddAttr<bool>(
"use_cudnn",
"(bool, default true) Only used in cudnn kernel, need install cudnn")
.SetDefault(true);
AddComment(R"DOC(
It sample input X by grid gennerate by AffineGridOp. The grid of shape
[N, H, W, 2] is the concatenation of (x, y) coordinates with shape
[N, H, W] each, with x indexing the 4th-D(W) of input feature map and y to
indexng the 3rd-D(H), finally results is the bilinear interpolation value
of 4 nearest corner points.
This operation samples input X by using bilinear interpolation based on
flow field grid, which is usually gennerated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexng the 3rd
dimention (in height dimension), finally results is the bilinear
interpolation value of 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
......@@ -154,8 +160,8 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_);
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
......
......@@ -19,7 +19,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
......@@ -31,7 +30,6 @@ using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Array3 = Eigen::DSizes<int64_t, 3>;
using Array4 = Eigen::DSizes<int64_t, 4>;
template <typename T>
static inline bool isInBound(T x, T y, T x_max, T y_max) {
if (x < 0 || x > x_max || y < 0 || y > y_max) {
......@@ -40,16 +38,17 @@ static inline bool isInBound(T x, T y, T x_max, T y_max) {
return true;
}
template <typename DeviceContext, typename T>
static void CalcGridLocations(const DeviceContext& ctx, const Tensor& grid,
Tensor* x_w, Tensor* x_e, Tensor* y_n, Tensor* y_s,
Tensor* d_w, Tensor* d_e, Tensor* d_n, Tensor* d_s) {
template <typename T>
static void CalcGridLocations(const platform::CPUDeviceContext& ctx,
const Tensor& grid, Tensor* x_w, Tensor* x_e,
Tensor* y_n, Tensor* y_s, Tensor* d_w,
Tensor* d_e, Tensor* d_n, Tensor* d_s) {
auto& place = *ctx.eigen_device();
const int n = grid.dims()[0];
const int h = grid.dims()[1];
const int w = grid.dims()[2];
const T x_max = static_cast<T> (w - 1);
const T y_max = static_cast<T> (h - 1);
const T x_max = static_cast<T>(w - 1);
const T y_max = static_cast<T>(h - 1);
// split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim
Tensor grid_x, grid_y;
......@@ -117,7 +116,9 @@ static void GetGridPointValue(const Tensor& input, Tensor* output,
for (int l = 0; l < w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, k, l) = input_t(i, j, (int)round(y_t(i, k, l)), (int)round(x_t(i, k, l)));
output_t(i, j, k, l) =
input_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l))));
}
}
}
......@@ -126,9 +127,10 @@ static void GetGridPointValue(const Tensor& input, Tensor* output,
}
template <typename T>
static void GatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input_grad,
const Tensor& x, const Tensor& y,
const Tensor& d1, const Tensor& d2) {
static void GatherOutputGradToInputGrad(const Tensor& output_grad,
Tensor* input_grad, const Tensor& x,
const Tensor& y, const Tensor& d1,
const Tensor& d2) {
const int n = output_grad.dims()[0];
const int c = output_grad.dims()[1];
const int h = output_grad.dims()[2];
......@@ -143,10 +145,11 @@ static void GatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input
for (int i = 0; i < n; i++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
if(isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i, j, (int) y_t(i, k, l), (int) x_t(i, k, l)) +=
output_grad_t(i, j, k ,l) * d1_t(i, k, l) * d2_t(i, k, l);
input_grad_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l)))) +=
output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l);
}
}
}
......@@ -154,8 +157,6 @@ static void GatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input
}
}
template <typename DeviceContext, typename T>
class GridSampleOpKernel : public framework::OpKernel<T> {
public:
......@@ -172,10 +173,9 @@ class GridSampleOpKernel : public framework::OpKernel<T> {
// calc locations and distances of 4 corner points
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
CalcGridLocations<DeviceContext, T>(ctx.template device_context<DeviceContext>(),
*grid,
&x_w, &x_e, &y_n, &y_s,
&d_w, &d_e, &d_n, &d_s);
CalcGridLocations<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, &x_w,
&x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s);
auto* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
......@@ -198,22 +198,25 @@ class GridSampleOpKernel : public framework::OpKernel<T> {
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto d_w_scaled_t = d_w_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_e_scaled_t = d_e_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_n_scaled_t = d_n_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_s_scaled_t = d_s_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_w_scaled_t =
d_w_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_e_scaled_t =
d_e_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_n_scaled_t =
d_n_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_s_scaled_t =
d_s_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto output_t = EigenTensor<T, 4>::From(*output);
//bilinear interpolaetion by 4 corner points
output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t
+ v_en_t * d_w_scaled_t * d_s_scaled_t
+ v_ws_t * d_e_scaled_t * d_n_scaled_t
+ v_es_t * d_w_scaled_t * d_n_scaled_t;
// bilinear interpolaetion by 4 corner points
output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t +
v_en_t * d_w_scaled_t * d_s_scaled_t +
v_ws_t * d_e_scaled_t * d_n_scaled_t +
v_es_t * d_w_scaled_t * d_n_scaled_t;
}
};
template <typename DeviceContext, typename T>
......@@ -242,16 +245,19 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> {
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
CalcGridLocations<DeviceContext, T>(ctx.template device_context<DeviceContext>(),
*grid,
&x_w, &x_e, &y_n, &y_s,
&d_w, &d_e, &d_n, &d_s);
CalcGridLocations<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, &x_w,
&x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s);
// gather output grad value to input grad by corner point coords and weight
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_n, d_e, d_s);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_s, d_e, d_n);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_n, d_w, d_s);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_s, d_w, d_n);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_n, d_e,
d_s);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_s, d_e,
d_n);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_n, d_w,
d_s);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_s, d_w,
d_n);
// calc 4 corner points value
Tensor v_wn, v_en, v_ws, v_es;
......@@ -281,15 +287,17 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> {
auto grid_grad_x_t = EigenTensor<T, 3>::From(grid_grad_x).setConstant(0.0);
auto grid_grad_y_t = EigenTensor<T, 3>::From(grid_grad_y).setConstant(0.0);
for (int i = 0; i < n; i++) {
for(int j = 0; j < c; j++) {
for(int k = 0; k < h; k++) {
for(int l = 0; l < w; l++) {
grid_grad_x_t(i, k, l) += ((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l)
+ (v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l))
* output_grad_t(i, j, k, l);
grid_grad_y_t(i, k, l) += ((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l)
+ (v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l))
* output_grad_t(i, j, k, l);
for (int j = 0; j < c; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
grid_grad_x_t(i, k, l) +=
((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
(v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
output_grad_t(i, j, k, l);
grid_grad_y_t(i, k, l) +=
((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
(v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
output_grad_t(i, j, k, l);
}
}
}
......@@ -308,7 +316,6 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> {
grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
}
}
};
} // namespace operators
......
......@@ -92,7 +92,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreateSpatialTransformerDescriptor); \
__macro(cudnnSetSpatialTransformerNdDescriptor); \
__macro(cudnnDestroySpatialTransformerDescriptor);\
__macro(cudnnDestroySpatialTransformerDescriptor); \
__macro(cudnnSpatialTfGridGeneratorForward); \
__macro(cudnnSpatialTfGridGeneratorBackward); \
__macro(cudnnSpatialTfSamplerForward); \
......
......@@ -7586,11 +7586,13 @@ def hash(input, hash_size, num_hash=1, name=None):
@templatedoc()
def grid_sampler(x, grid, name=None):
"""
It sample input X by grid gennerate by AffineGridOp. The grid of shape
[N, H, W, 2] is the concatenation of (x, y) coordinates with shape
[N, H, W] each, with x indexing the 4th-D(W) of input feature map and y to
indexng the 3rd-D(H), finally results is the bilinear interpolation value
of 4 nearest corner points.
This operation samples input X by using bilinear interpolation based on
flow field grid, which is usually gennerated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexng the 3rd
dimention (in height dimension), finally results is the bilinear
interpolation value of 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
......@@ -7636,7 +7638,16 @@ def grid_sampler(x, grid, name=None):
name (str, default None): The name of this layer.
Returns:
out(Variable): Output data indices by grid from x of shape [N, C, H, W].
out(Variable): Output of shape [N, C, H, W] data samples input X
using bilnear interpolation based on input grid.
Exmples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[3, 10, 32, 32], dtype='float32')
theta = fluid.layers.data(name='theta', shape=[3, 2, 3], dtype='float32')
grid = fluid.layers.affine_grid(input=theta, size=[3, 10, 32, 32]})
out = fluid.layers.grid_sampler(x=x, grid=grid)
"""
helper = LayerHelper("grid_sampler", **locals())
......@@ -7649,10 +7660,6 @@ def grid_sampler(x, grid, name=None):
out = helper.create_tmp_variable(x.dtype)
ipts = {'X': x, 'Grid': grid}
helper.apppend_op(
type='grid_sampler',
inputs=ipts,
outputs={'Output', out})
helper.apppend_op(type='grid_sampler', inputs=ipts, outputs={'Output', out})
return out
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
......@@ -37,6 +36,7 @@ def AffineGrid(theta, size):
return ret.reshape([n, h, w, 2]).astype("float32")
def getGridPointValue(data, x, y):
data_shape = data.shape
N = data_shape[0]
......@@ -47,13 +47,15 @@ def getGridPointValue(data, x, y):
for i in range(N):
for j in range(H):
for k in range(W):
if y[i, j, k] < 0 or y[i, j, k] > H - 1 or x[i, j, k] < 0 or x[i, j, k] > W - 1:
if y[i, j, k] < 0 or y[i, j, k] > H - 1 or x[i, j, k] < 0 or x[
i, j, k] > W - 1:
out[i, :, j, k] = 0
else:
out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]]
return out
def GridSampler(data, grid):
dims = data.shape
N = dims[0]
......@@ -87,6 +89,7 @@ def GridSampler(data, grid):
out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float32')
return out
class TestGridSamplerOp(OpTest):
def setUp(self):
self.initTestCase()
......@@ -115,5 +118,6 @@ class TestGridSamplerOp(OpTest):
self.grid_shape = (2, 7, 3, 2)
self.theta_shape = (2, 2, 3)
if __name__ == "__main__":
unittest.main()
......@@ -868,13 +868,12 @@ class TestBook(unittest.TestCase):
def test_affine_grid_gen(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[2, 5, 7, 3 ], dtype='float32')
grid = layers.data(name='grid', shape=[2, 5, 7, 2], dtype='float32' )
x = layers.data(name='x', shape=[2, 5, 7, 3], dtype='float32')
grid = layers.data(name='grid', shape=[2, 5, 7, 2], dtype='float32')
out = layers.grid_sampler(x, grid)
self.assertIsNotNone(out)
print(str(program))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册