Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
64f3e3ed
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
64f3e3ed
编写于
11月 03, 2018
作者:
K
Kaipeng Deng
提交者:
GitHub
11月 03, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14069 from heavengate/grid_sampler
Grid sampler operator for spatial transformer network.
上级
8690deb0
decaeb1c
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
872 addition
and
0 deletion
+872
-0
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
+132
-0
paddle/fluid/operators/grid_sampler_op.cc
paddle/fluid/operators/grid_sampler_op.cc
+203
-0
paddle/fluid/operators/grid_sampler_op.h
paddle/fluid/operators/grid_sampler_op.h
+322
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+82
-0
python/paddle/fluid/tests/unittests/test_grid_sampler_op.py
python/paddle/fluid/tests/unittests/test_grid_sampler_op.py
+123
-0
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+9
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
64f3e3ed
...
...
@@ -178,6 +178,7 @@ paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], var
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
...
...
paddle/fluid/operators/grid_sampler_cudnn_op.cu.cc
0 → 100644
浏览文件 @
64f3e3ed
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
using
ScopedTensorDescriptor
=
platform
::
ScopedTensorDescriptor
;
using
DataLayout
=
platform
::
DataLayout
;
using
ScopedSpatialTransformerDescriptor
=
platform
::
ScopedSpatialTransformerDescriptor
;
template
<
typename
T
>
using
CudnnDataType
=
platform
::
CudnnDataType
<
T
>
;
template
<
typename
T
>
class
CUDNNGridSampleOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use CUDAPlace"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
grid
=
ctx
.
Input
<
Tensor
>
(
"Grid"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Output"
);
int
n
=
input
->
dims
()[
0
];
int
c
=
input
->
dims
()[
1
];
int
h
=
input
->
dims
()[
2
];
int
w
=
input
->
dims
()[
3
];
const
int
size
[
4
]
=
{
n
,
c
,
h
,
w
};
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
grid_data
=
grid
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
ScopedSpatialTransformerDescriptor
st_desc
;
cudnnSpatialTransformerDescriptor_t
cudnn_st_desc
=
st_desc
.
descriptor
<
T
>
(
4
,
size
);
ScopedTensorDescriptor
input_desc
;
ScopedTensorDescriptor
output_desc
;
cudnnTensorDescriptor_t
cudnn_input_desc
=
input_desc
.
descriptor
<
T
>
(
DataLayout
::
kNCHW
,
framework
::
vectorize2int
(
input
->
dims
()));
cudnnTensorDescriptor_t
cudnn_output_desc
=
output_desc
.
descriptor
<
T
>
(
DataLayout
::
kNCHW
,
framework
::
vectorize2int
(
output
->
dims
()));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSpatialTfSamplerForward
(
handle
,
cudnn_st_desc
,
CudnnDataType
<
T
>::
kOne
(),
cudnn_input_desc
,
input_data
,
grid_data
,
CudnnDataType
<
T
>::
kZero
(),
cudnn_output_desc
,
output_data
));
}
};
template
<
typename
T
>
class
CUDNNGridSampleGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use CUDAPlace"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
grid
=
ctx
.
Input
<
Tensor
>
(
"Grid"
);
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
grid_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Grid"
));
auto
output_grad_dims
=
output_grad
->
dims
();
const
int
n
=
output_grad_dims
[
0
];
const
int
c
=
output_grad_dims
[
1
];
const
int
h
=
output_grad_dims
[
2
];
const
int
w
=
output_grad_dims
[
3
];
const
int
size
[
4
]
=
{
n
,
c
,
h
,
w
};
ScopedSpatialTransformerDescriptor
st_dest
;
cudnnSpatialTransformerDescriptor_t
cudnn_st_dest
=
st_dest
.
descriptor
<
T
>
(
4
,
size
);
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
grid_data
=
grid
->
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
->
data
<
T
>
();
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
output_grad_dims
,
ctx
.
GetPlace
());
T
*
grid_grad_data
=
grid_grad
->
mutable_data
<
T
>
({
n
,
h
,
w
,
2
},
ctx
.
GetPlace
());
ScopedTensorDescriptor
input_desc
;
ScopedTensorDescriptor
input_grad_desc
;
ScopedTensorDescriptor
output_grad_desc
;
cudnnTensorDescriptor_t
cudnn_input_desc
=
input_desc
.
descriptor
<
T
>
(
DataLayout
::
kNCHW
,
framework
::
vectorize2int
(
input
->
dims
()));
cudnnTensorDescriptor_t
cudnn_input_grad_desc
=
input_grad_desc
.
descriptor
<
T
>
(
DataLayout
::
kNCHW
,
framework
::
vectorize2int
(
input_grad
->
dims
()));
cudnnTensorDescriptor_t
cudnn_output_grad_desc
=
output_grad_desc
.
descriptor
<
T
>
(
DataLayout
::
kNCHW
,
framework
::
vectorize2int
(
output_grad
->
dims
()));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSpatialTfSamplerBackward
(
handle
,
cudnn_st_dest
,
CudnnDataType
<
T
>::
kOne
(),
cudnn_input_desc
,
input_data
,
CudnnDataType
<
T
>::
kZero
(),
cudnn_input_grad_desc
,
input_grad_data
,
CudnnDataType
<
T
>::
kOne
(),
cudnn_output_grad_desc
,
output_grad_data
,
grid_data
,
CudnnDataType
<
T
>::
kZero
(),
grid_grad_data
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_KERNEL
(
grid_sampler
,
CUDNN
,
plat
::
CUDAPlace
,
paddle
::
operators
::
CUDNNGridSampleOpKernel
<
float
>
,
paddle
::
operators
::
CUDNNGridSampleOpKernel
<
double
>
);
REGISTER_OP_KERNEL
(
grid_sampler_grad
,
CUDNN
,
plat
::
CUDAPlace
,
paddle
::
operators
::
CUDNNGridSampleGradOpKernel
<
float
>
,
paddle
::
operators
::
CUDNNGridSampleGradOpKernel
<
double
>
);
paddle/fluid/operators/grid_sampler_op.cc
0 → 100644
浏览文件 @
64f3e3ed
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
GridSampleOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of GridSampleOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Grid"
),
"Input(Grid) of GridSampleOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Output"
),
"Output(Output) of GridSampleOp should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
grid_dims
=
ctx
->
GetInputDim
(
"Grid"
);
PADDLE_ENFORCE
(
x_dims
.
size
()
==
4
,
"Input(X) of GridSampleOp should be 4-D Tensor."
);
PADDLE_ENFORCE
(
grid_dims
.
size
()
==
4
,
"Input(Grid) of GridSampleOp should be 4-D Tensor."
);
PADDLE_ENFORCE
(
grid_dims
[
3
]
==
2
,
"Input(Grid) dims[3] should be 2."
);
PADDLE_ENFORCE_EQ
(
grid_dims
[
0
],
x_dims
[
0
],
"Input(X) and Input(Grid) dims[0] should be equal."
);
PADDLE_ENFORCE_EQ
(
grid_dims
[
1
],
x_dims
[
2
],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal."
);
PADDLE_ENFORCE_EQ
(
grid_dims
[
2
],
x_dims
[
3
],
"Input(X) dims[3] and Input(Grid) dims[2] should be equal."
);
ctx
->
SetOutputDim
(
"Output"
,
x_dims
);
ctx
->
ShareLoD
(
"X"
,
"Output"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
#endif
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
library_
);
}
};
class
GridSampleOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) The input data of GridSampleOp, "
"This is a 4-D tensor with shape of [N, C, H, W]"
);
AddInput
(
"Grid"
,
"(Tensor) The input grid of GridSampleOp generated by AffineGridOp, "
"This is a 4-D tensor with shape of [N, H, W, 2] is the concatenation "
"of x and y coordinates with shape [N, H, W] in last dimention"
);
AddOutput
(
"Output"
,
"(Tensor) Output tensor with shape [N, C, H, W]"
);
AddAttr
<
bool
>
(
"use_cudnn"
,
"(bool, default true) Only used in cudnn kernel, need install cudnn"
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
This operation samples input X by using bilinear interpolation based on
flow field grid, which is usually gennerated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexng the 3rd
dimention (in height dimension), finally results is the bilinear
interpolation value of 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points.
wn ------- y_n ------- en
| | |
| d_n |
| | |
x_w --d_w-- grid--d_e-- x_e
| | |
| d_s |
| | |
ws ------- y_s ------- wn
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_s + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-east point value
es = X[:, :, y_s, x_w] // north-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
)DOC"
);
}
};
class
GridSampleOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
input_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
grid_dims
=
ctx
->
GetInputDim
(
"Grid"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
input_dims
);
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Grid"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Grid"
),
grid_dims
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
#endif
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
library_
);
}
};
class
GridSampleGradMaker
:
public
framework
::
SingleGradOpDescMaker
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
op
=
new
framework
::
OpDesc
();
op
->
SetType
(
"grid_sampler_grad"
);
op
->
SetInput
(
"X"
,
Input
(
"X"
));
op
->
SetInput
(
"Grid"
,
Input
(
"Grid"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Output"
),
OutputGrad
(
"Output"
));
op
->
SetAttrMap
(
Attrs
());
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Grid"
),
InputGrad
(
"Grid"
));
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
grid_sampler
,
ops
::
GridSampleOp
,
ops
::
GridSampleOpMaker
,
ops
::
GridSampleGradMaker
);
REGISTER_OPERATOR
(
grid_sampler_grad
,
ops
::
GridSampleOpGrad
);
REGISTER_OP_CPU_KERNEL
(
grid_sampler
,
ops
::
GridSampleOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
GridSampleOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
grid_sampler_grad
,
ops
::
GridSampleGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
GridSampleGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/grid_sampler_op.h
0 → 100644
浏览文件 @
64f3e3ed
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
using
Array3
=
Eigen
::
DSizes
<
int64_t
,
3
>
;
using
Array4
=
Eigen
::
DSizes
<
int64_t
,
4
>
;
template
<
typename
T
>
static
inline
bool
isInBound
(
T
x
,
T
y
,
T
x_max
,
T
y_max
)
{
if
(
x
<
0
||
x
>
x_max
||
y
<
0
||
y
>
y_max
)
{
return
false
;
}
return
true
;
}
template
<
typename
T
>
static
void
CalcGridLocations
(
const
platform
::
CPUDeviceContext
&
ctx
,
const
Tensor
&
grid
,
Tensor
*
x_w
,
Tensor
*
x_e
,
Tensor
*
y_n
,
Tensor
*
y_s
,
Tensor
*
d_w
,
Tensor
*
d_e
,
Tensor
*
d_n
,
Tensor
*
d_s
)
{
auto
&
place
=
*
ctx
.
eigen_device
();
const
int
n
=
grid
.
dims
()[
0
];
const
int
h
=
grid
.
dims
()[
1
];
const
int
w
=
grid
.
dims
()[
2
];
const
T
x_max
=
static_cast
<
T
>
(
w
-
1
);
const
T
y_max
=
static_cast
<
T
>
(
h
-
1
);
// split grid with shape (n, h, w, 2) into (x, y) by the 3rd Dim
Tensor
grid_x
,
grid_y
;
T
*
grid_x_data
=
grid_x
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
T
*
grid_y_data
=
grid_y
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
const
T
*
grid_data
=
grid
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
n
*
h
*
w
;
i
++
)
{
grid_x_data
[
i
]
=
grid_data
[
2
*
i
];
grid_y_data
[
i
]
=
grid_data
[(
2
*
i
)
+
1
];
}
Tensor
ones
;
ones
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
auto
ones_t
=
EigenTensor
<
T
,
3
>::
From
(
ones
).
setConstant
(
1.0
);
// scale grid to [0, h-1/w-1]
auto
grid_x_t
=
EigenTensor
<
T
,
3
>::
From
(
grid_x
);
auto
grid_y_t
=
EigenTensor
<
T
,
3
>::
From
(
grid_y
);
grid_x_t
.
device
(
place
)
=
0.5
*
((
grid_x_t
+
ones_t
)
*
x_max
);
grid_y_t
.
device
(
place
)
=
0.5
*
((
grid_y_t
+
ones_t
)
*
y_max
);
// calculate coords of 4 corner points
x_w
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
x_e
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
y_n
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
y_s
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
auto
x_w_t
=
EigenTensor
<
T
,
3
>::
From
(
*
x_w
);
auto
x_e_t
=
EigenTensor
<
T
,
3
>::
From
(
*
x_e
);
auto
y_n_t
=
EigenTensor
<
T
,
3
>::
From
(
*
y_n
);
auto
y_s_t
=
EigenTensor
<
T
,
3
>::
From
(
*
y_s
);
x_w_t
.
device
(
place
)
=
grid_x_t
.
floor
();
x_e_t
.
device
(
place
)
=
x_w_t
+
ones_t
;
y_n_t
.
device
(
place
)
=
grid_y_t
.
floor
();
y_s_t
.
device
(
place
)
=
y_n_t
+
ones_t
;
// calculate distances to 4 sides
d_w
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
d_e
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
d_n
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
d_s
->
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
auto
d_w_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_w
);
auto
d_e_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_e
);
auto
d_n_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_n
);
auto
d_s_t
=
EigenTensor
<
T
,
3
>::
From
(
*
d_s
);
d_w_t
.
device
(
place
)
=
grid_x_t
-
x_w_t
;
d_e_t
.
device
(
place
)
=
x_e_t
-
grid_x_t
;
d_n_t
.
device
(
place
)
=
grid_y_t
-
y_n_t
;
d_s_t
.
device
(
place
)
=
y_s_t
-
grid_y_t
;
}
template
<
typename
T
>
static
void
GetGridPointValue
(
const
Tensor
&
input
,
Tensor
*
output
,
const
Tensor
&
x
,
const
Tensor
&
y
)
{
const
int
n
=
input
.
dims
()[
0
];
const
int
c
=
input
.
dims
()[
1
];
const
int
h
=
input
.
dims
()[
2
];
const
int
w
=
input
.
dims
()[
3
];
auto
x_t
=
EigenTensor
<
T
,
3
>::
From
(
x
);
auto
y_t
=
EigenTensor
<
T
,
3
>::
From
(
y
);
auto
output_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output
).
setConstant
((
T
)
0
);
auto
input_t
=
EigenTensor
<
T
,
4
>::
From
(
input
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
if
(
isInBound
(
x_t
(
i
,
k
,
l
),
y_t
(
i
,
k
,
l
),
(
T
)(
w
-
1
),
(
T
)(
h
-
1
)))
{
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
output_t
(
i
,
j
,
k
,
l
)
=
input_t
(
i
,
j
,
static_cast
<
int
>
(
round
(
y_t
(
i
,
k
,
l
))),
static_cast
<
int
>
(
round
(
x_t
(
i
,
k
,
l
))));
}
}
}
}
}
}
template
<
typename
T
>
static
void
GatherOutputGradToInputGrad
(
const
Tensor
&
output_grad
,
Tensor
*
input_grad
,
const
Tensor
&
x
,
const
Tensor
&
y
,
const
Tensor
&
d1
,
const
Tensor
&
d2
)
{
const
int
n
=
output_grad
.
dims
()[
0
];
const
int
c
=
output_grad
.
dims
()[
1
];
const
int
h
=
output_grad
.
dims
()[
2
];
const
int
w
=
output_grad
.
dims
()[
3
];
auto
x_t
=
EigenTensor
<
T
,
3
>::
From
(
x
);
auto
y_t
=
EigenTensor
<
T
,
3
>::
From
(
y
);
auto
d1_t
=
EigenTensor
<
T
,
3
>::
From
(
d1
);
auto
d2_t
=
EigenTensor
<
T
,
3
>::
From
(
d2
);
auto
input_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
*
input_grad
);
auto
output_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
output_grad
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
if
(
isInBound
(
x_t
(
i
,
k
,
l
),
y_t
(
i
,
k
,
l
),
(
T
)(
w
-
1
),
(
T
)(
h
-
1
)))
{
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
input_grad_t
(
i
,
j
,
static_cast
<
int
>
(
round
(
y_t
(
i
,
k
,
l
))),
static_cast
<
int
>
(
round
(
x_t
(
i
,
k
,
l
))))
+=
output_grad_t
(
i
,
j
,
k
,
l
)
*
d1_t
(
i
,
k
,
l
)
*
d2_t
(
i
,
k
,
l
);
}
}
}
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
class
GridSampleOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
grid
=
ctx
.
Input
<
Tensor
>
(
"Grid"
);
const
int
n
=
input
->
dims
()[
0
];
const
int
c
=
input
->
dims
()[
1
];
const
int
h
=
input
->
dims
()[
2
];
const
int
w
=
input
->
dims
()[
3
];
// calc locations and distances of 4 corner points
Tensor
x_w
,
x_e
,
y_n
,
y_s
;
Tensor
d_w
,
d_e
,
d_n
,
d_s
;
CalcGridLocations
<
T
>
(
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>(),
*
grid
,
&
x_w
,
&
x_e
,
&
y_n
,
&
y_s
,
&
d_w
,
&
d_e
,
&
d_n
,
&
d_s
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Output"
);
output
->
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
output
,
static_cast
<
T
>
(
0
));
// calc 4 corner points value
Tensor
v_wn
,
v_en
,
v_ws
,
v_es
;
v_wn
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_en
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_ws
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_es
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
GetGridPointValue
<
T
>
(
*
input
,
&
v_wn
,
x_w
,
y_n
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_en
,
x_e
,
y_n
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_ws
,
x_w
,
y_s
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_es
,
x_e
,
y_s
);
auto
d_w_t
=
EigenTensor
<
T
,
3
>::
From
(
d_w
);
auto
d_e_t
=
EigenTensor
<
T
,
3
>::
From
(
d_e
);
auto
d_n_t
=
EigenTensor
<
T
,
3
>::
From
(
d_n
);
auto
d_s_t
=
EigenTensor
<
T
,
3
>::
From
(
d_s
);
auto
d_w_scaled_t
=
d_w_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
d_e_scaled_t
=
d_e_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
d_n_scaled_t
=
d_n_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
d_s_scaled_t
=
d_s_t
.
reshape
(
Array4
(
n
,
1
,
h
,
w
)).
broadcast
(
Array4
(
1
,
c
,
1
,
1
));
auto
v_wn_t
=
EigenTensor
<
T
,
4
>::
From
(
v_wn
);
auto
v_en_t
=
EigenTensor
<
T
,
4
>::
From
(
v_en
);
auto
v_ws_t
=
EigenTensor
<
T
,
4
>::
From
(
v_ws
);
auto
v_es_t
=
EigenTensor
<
T
,
4
>::
From
(
v_es
);
auto
output_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output
);
// bilinear interpolaetion by 4 corner points
output_t
.
device
(
place
)
=
v_wn_t
*
d_e_scaled_t
*
d_s_scaled_t
+
v_en_t
*
d_w_scaled_t
*
d_s_scaled_t
+
v_ws_t
*
d_e_scaled_t
*
d_n_scaled_t
+
v_es_t
*
d_w_scaled_t
*
d_n_scaled_t
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
GridSampleGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
grid
=
ctx
.
Input
<
Tensor
>
(
"Grid"
);
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Output"
));
const
int
n
=
input
->
dims
()[
0
];
const
int
c
=
input
->
dims
()[
1
];
const
int
h
=
input
->
dims
()[
2
];
const
int
w
=
input
->
dims
()[
3
];
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
input_grad
->
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
input_grad
,
static_cast
<
T
>
(
0
));
auto
*
grid_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Grid"
));
grid_grad
->
mutable_data
<
T
>
({
n
,
h
,
w
,
2
},
ctx
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
grid_grad
,
static_cast
<
T
>
(
0
));
Tensor
x_w
,
x_e
,
y_n
,
y_s
;
Tensor
d_w
,
d_e
,
d_n
,
d_s
;
CalcGridLocations
<
T
>
(
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>(),
*
grid
,
&
x_w
,
&
x_e
,
&
y_n
,
&
y_s
,
&
d_w
,
&
d_e
,
&
d_n
,
&
d_s
);
// gather output grad value to input grad by corner point coords and weight
GatherOutputGradToInputGrad
<
T
>
(
*
output_grad
,
input_grad
,
x_w
,
y_n
,
d_e
,
d_s
);
GatherOutputGradToInputGrad
<
T
>
(
*
output_grad
,
input_grad
,
x_w
,
y_s
,
d_e
,
d_n
);
GatherOutputGradToInputGrad
<
T
>
(
*
output_grad
,
input_grad
,
x_e
,
y_n
,
d_w
,
d_s
);
GatherOutputGradToInputGrad
<
T
>
(
*
output_grad
,
input_grad
,
x_e
,
y_s
,
d_w
,
d_n
);
// calc 4 corner points value
Tensor
v_wn
,
v_en
,
v_ws
,
v_es
;
v_wn
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_en
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_ws
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
v_es
.
mutable_data
<
T
>
({
n
,
c
,
h
,
w
},
ctx
.
GetPlace
());
GetGridPointValue
<
T
>
(
*
input
,
&
v_wn
,
x_w
,
y_n
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_en
,
x_e
,
y_n
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_ws
,
x_w
,
y_s
);
GetGridPointValue
<
T
>
(
*
input
,
&
v_es
,
x_e
,
y_s
);
auto
v_wn_t
=
EigenTensor
<
T
,
4
>::
From
(
v_wn
);
auto
v_en_t
=
EigenTensor
<
T
,
4
>::
From
(
v_en
);
auto
v_ws_t
=
EigenTensor
<
T
,
4
>::
From
(
v_ws
);
auto
v_es_t
=
EigenTensor
<
T
,
4
>::
From
(
v_es
);
auto
d_w_t
=
EigenTensor
<
T
,
3
>::
From
(
d_w
);
auto
d_e_t
=
EigenTensor
<
T
,
3
>::
From
(
d_e
);
auto
d_n_t
=
EigenTensor
<
T
,
3
>::
From
(
d_n
);
auto
d_s_t
=
EigenTensor
<
T
,
3
>::
From
(
d_s
);
auto
output_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output_grad
);
Tensor
grid_grad_x
,
grid_grad_y
;
grid_grad_x
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
grid_grad_y
.
mutable_data
<
T
>
({
n
,
h
,
w
},
ctx
.
GetPlace
());
auto
grid_grad_x_t
=
EigenTensor
<
T
,
3
>::
From
(
grid_grad_x
).
setConstant
(
0.0
);
auto
grid_grad_y_t
=
EigenTensor
<
T
,
3
>::
From
(
grid_grad_y
).
setConstant
(
0.0
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
for
(
int
k
=
0
;
k
<
h
;
k
++
)
{
for
(
int
l
=
0
;
l
<
w
;
l
++
)
{
grid_grad_x_t
(
i
,
k
,
l
)
+=
((
v_en_t
(
i
,
j
,
k
,
l
)
-
v_wn_t
(
i
,
j
,
k
,
l
))
*
d_s_t
(
i
,
k
,
l
)
+
(
v_es_t
(
i
,
j
,
k
,
l
)
-
v_ws_t
(
i
,
j
,
k
,
l
))
*
d_n_t
(
i
,
k
,
l
))
*
output_grad_t
(
i
,
j
,
k
,
l
);
grid_grad_y_t
(
i
,
k
,
l
)
+=
((
v_ws_t
(
i
,
j
,
k
,
l
)
-
v_wn_t
(
i
,
j
,
k
,
l
))
*
d_e_t
(
i
,
k
,
l
)
+
(
v_es_t
(
i
,
j
,
k
,
l
)
-
v_en_t
(
i
,
j
,
k
,
l
))
*
d_w_t
(
i
,
k
,
l
))
*
output_grad_t
(
i
,
j
,
k
,
l
);
}
}
}
}
const
T
x_max
=
static_cast
<
T
>
(
w
-
1
);
const
T
y_max
=
static_cast
<
T
>
(
h
-
1
);
grid_grad_x_t
=
grid_grad_x_t
*
(
x_max
/
(
T
)
2
);
grid_grad_y_t
=
grid_grad_y_t
*
(
y_max
/
(
T
)
2
);
// gather grid_grad [x, y] in 3rd Dim
T
*
grid_grad_data
=
grid_grad
->
data
<
T
>
();
T
*
grid_grad_x_data
=
grid_grad_x
.
data
<
T
>
();
T
*
grid_grad_y_data
=
grid_grad_y
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
n
*
h
*
w
;
i
++
)
{
grid_grad_data
[
2
*
i
]
=
grid_grad_x_data
[
i
];
grid_grad_data
[
2
*
i
+
1
]
=
grid_grad_y_data
[
i
];
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/nn.py
浏览文件 @
64f3e3ed
...
...
@@ -158,6 +158,7 @@ __all__ = [
'sequence_reverse'
,
'affine_channel'
,
'hash'
,
'grid_sampler'
,
'log_loss'
,
'add_position_encoding'
,
]
...
...
@@ -7801,6 +7802,87 @@ def hash(input, hash_size, num_hash=1, name=None):
return
out
@
templatedoc
()
def
grid_sampler
(
x
,
grid
,
name
=
None
):
"""
This operation samples input X by using bilinear interpolation based on
flow field grid, which is usually gennerated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x is indexing the 4th dimension
(in width dimension) of input data x and grid_y is indexng the 3rd
dimention (in height dimension), finally results is the bilinear
interpolation value of 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points.
wn ------- y_n ------- en
| | |
| d_n |
| | |
x_w --d_w-- grid--d_e-- x_e
| | |
| d_s |
| | |
ws ------- y_s ------- wn
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_s + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-east point value
es = X[:, :, y_s, x_w] // north-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
Args:
x(Variable): Input data of shape [N, C, H, W].
grid(Variable): Input grid tensor of shape [N, H, W, 2].
name (str, default None): The name of this layer.
Returns:
out(Variable): Output of shape [N, C, H, W] data samples input X
using bilnear interpolation based on input grid.
Exmples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[3, 10, 32, 32], dtype='float32')
theta = fluid.layers.data(name='theta', shape=[3, 2, 3], dtype='float32')
grid = fluid.layers.affine_grid(input=theta, size=[3, 10, 32, 32]})
out = fluid.layers.grid_sampler(x=x, grid=grid)
"""
helper
=
LayerHelper
(
"grid_sampler"
,
**
locals
())
if
not
isinstance
(
x
,
Variable
):
return
ValueError
(
"The x should be a Variable"
)
if
not
isinstance
(
grid
,
Variable
):
return
ValueError
(
"The grid should be a Variable"
)
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
ipts
=
{
'X'
:
x
,
'Grid'
:
grid
}
helper
.
append_op
(
type
=
'grid_sampler'
,
inputs
=
ipts
,
outputs
=
{
'Output'
:
out
})
return
out
def
log_loss
(
input
,
label
,
epsilon
=
1e-4
,
name
=
None
):
"""
**Negative Log Loss Layer**
...
...
python/paddle/fluid/tests/unittests/test_grid_sampler_op.py
0 → 100644
浏览文件 @
64f3e3ed
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
def
AffineGrid
(
theta
,
size
):
n
=
size
[
0
]
h
=
size
[
2
]
w
=
size
[
3
]
h_idx
=
np
.
repeat
(
np
.
linspace
(
-
1
,
1
,
h
)[
np
.
newaxis
,
:],
w
,
axis
=
0
).
T
[:,
:,
np
.
newaxis
]
w_idx
=
np
.
repeat
(
np
.
linspace
(
-
1
,
1
,
w
)[
np
.
newaxis
,
:],
h
,
axis
=
0
)[:,
:,
np
.
newaxis
]
grid
=
np
.
concatenate
(
[
w_idx
,
h_idx
,
np
.
ones
([
h
,
w
,
1
])],
axis
=
2
)
# h * w * 3
grid
=
np
.
repeat
(
grid
[
np
.
newaxis
,
:],
size
[
0
],
axis
=
0
)
# n * h * w *3
ret
=
np
.
zeros
([
n
,
h
*
w
,
2
])
theta
=
theta
.
transpose
([
0
,
2
,
1
])
for
i
in
range
(
len
(
theta
)):
ret
[
i
]
=
np
.
dot
(
grid
[
i
].
reshape
([
h
*
w
,
3
]),
theta
[
i
])
return
ret
.
reshape
([
n
,
h
,
w
,
2
]).
astype
(
"float32"
)
def
getGridPointValue
(
data
,
x
,
y
):
data_shape
=
data
.
shape
N
=
data_shape
[
0
]
H
=
data_shape
[
2
]
W
=
data_shape
[
3
]
out
=
np
.
zeros
(
data_shape
,
dtype
=
'float'
)
for
i
in
range
(
N
):
for
j
in
range
(
H
):
for
k
in
range
(
W
):
if
y
[
i
,
j
,
k
]
<
0
or
y
[
i
,
j
,
k
]
>
H
-
1
or
x
[
i
,
j
,
k
]
<
0
or
x
[
i
,
j
,
k
]
>
W
-
1
:
out
[
i
,
:,
j
,
k
]
=
0
else
:
out
[
i
,
:,
j
,
k
]
=
data
[
i
,
:,
y
[
i
,
j
,
k
],
x
[
i
,
j
,
k
]]
return
out
def
GridSampler
(
data
,
grid
):
dims
=
data
.
shape
N
=
dims
[
0
]
C
=
dims
[
1
]
H
=
dims
[
2
]
W
=
dims
[
3
]
x
=
grid
[:,
:,
:,
0
]
y
=
grid
[:,
:,
:,
1
]
y_max
=
H
-
1
x_max
=
W
-
1
x
=
0.5
*
((
x
.
astype
(
'float32'
)
+
1.0
)
*
x_max
)
y
=
0.5
*
((
y
.
astype
(
'float32'
)
+
1.0
)
*
y_max
)
x0
=
np
.
floor
(
x
).
astype
(
'int32'
)
x1
=
x0
+
1
y0
=
np
.
floor
(
y
).
astype
(
'int32'
)
y1
=
y0
+
1
wa
=
np
.
tile
(((
x1
-
x
)
*
(
y1
-
y
)).
reshape
((
N
,
1
,
H
,
W
)),
(
1
,
C
,
1
,
1
))
wb
=
np
.
tile
(((
x1
-
x
)
*
(
y
-
y0
)).
reshape
((
N
,
1
,
H
,
W
)),
(
1
,
C
,
1
,
1
))
wc
=
np
.
tile
(((
x
-
x0
)
*
(
y1
-
y
)).
reshape
((
N
,
1
,
H
,
W
)),
(
1
,
C
,
1
,
1
))
wd
=
np
.
tile
(((
x
-
x0
)
*
(
y
-
y0
)).
reshape
((
N
,
1
,
H
,
W
)),
(
1
,
C
,
1
,
1
))
va
=
getGridPointValue
(
data
,
x0
,
y0
)
vb
=
getGridPointValue
(
data
,
x0
,
y1
)
vc
=
getGridPointValue
(
data
,
x1
,
y0
)
vd
=
getGridPointValue
(
data
,
x1
,
y1
)
out
=
(
wa
*
va
+
wb
*
vb
+
wc
*
vc
+
wd
*
vd
).
astype
(
'float32'
)
return
out
class
TestGridSamplerOp
(
OpTest
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
op_type
=
'grid_sampler'
x
=
np
.
random
.
randint
(
0
,
255
,
self
.
x_shape
).
astype
(
'float32'
)
theta
=
np
.
zeros
(
self
.
theta_shape
).
astype
(
'float32'
)
for
i
in
range
(
self
.
theta_shape
[
0
]):
for
j
in
range
(
2
):
for
k
in
range
(
3
):
theta
[
i
,
j
,
k
]
=
np
.
random
.
rand
(
1
)[
0
]
grid
=
AffineGrid
(
theta
,
self
.
x_shape
)
self
.
inputs
=
{
'X'
:
x
,
'Grid'
:
grid
}
self
.
attrs
=
{
'use_cudnn'
:
True
}
self
.
outputs
=
{
'Output'
:
GridSampler
(
x
,
grid
)}
def
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-3
)
def
test_check_grad_normal
(
self
):
self
.
check_grad
([
'X'
,
'Grid'
],
'Output'
,
max_relative_error
=
0.61
)
def
initTestCase
(
self
):
self
.
x_shape
=
(
2
,
5
,
7
,
3
)
self
.
grid_shape
=
(
2
,
7
,
3
,
2
)
self
.
theta_shape
=
(
2
,
2
,
3
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
64f3e3ed
...
...
@@ -865,6 +865,15 @@ class TestBook(unittest.TestCase):
self
.
assertIsNotNone
(
out
)
print
(
str
(
program
))
def
test_grid_sampler
(
self
):
program
=
Program
()
with
program_guard
(
program
):
x
=
layers
.
data
(
name
=
'x'
,
shape
=
[
3
,
5
,
7
],
dtype
=
'float32'
)
grid
=
layers
.
data
(
name
=
'grid'
,
shape
=
[
5
,
7
,
2
],
dtype
=
'float32'
)
out
=
layers
.
grid_sampler
(
x
,
grid
)
self
.
assertIsNotNone
(
out
)
print
(
str
(
program
))
def
test_affine_grid
(
self
):
program
=
Program
()
with
program_guard
(
program
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录