提交 593e1b18 编写于 作者: D dengkaipeng 提交者: dengkaipeng

fix some bugs and add some doc for GridSampleOp

上级 0bb0e0c1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. Licensed under the Apache License, Version 2.0 (the "License");
You may obtain a copy of the License at you may not use this file except in compliance with the License.
http://www.apache.org/licenses/LICENSE-2.0 You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, http://www.apache.org/licenses/LICENSE-2.0
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and Unless required by applicable law or agreed to in writing, software
limitations under the License. */ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -67,23 +67,66 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -67,23 +67,66 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor of GridSampleOp, " "(Tensor) The input data of GridSampleOp, "
"This is a 4-D tensor with shape of [N, C, H, W]"); "This is a 4-D tensor with shape of [N, C, H, W]");
AddInput( AddInput(
"Grid", "Grid",
"(Tensor) The output of AffineGridOp, " "(Tensor) The input grid of GridSampleOp generated by AffineGridOp, "
"This is a 4-D tensor with shape of [N, H, W, 2]"); "This is a 4-D tensor with shape of [N, H, W, 2] is the concatenation "
"of x and y coordinates with shape [N, H, W] in last dimention");
AddOutput( AddOutput(
"Output", "Output",
"(Tensor) Output tensor with shape [N, C, H, W]"); "(Tensor) Output tensor with shape [N, C, H, W]");
AddAttr<bool>( AddAttr<bool>(
"use_cudnn", "use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn") "(bool, default true) Only used in cudnn kernel, need install cudnn")
.SetDefault(true); .SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
It sample input X by grid gennerate by AffineGridOp. It sample input X by grid gennerate by AffineGridOp. The grid of shape
)DOC"); [N, H, W, 2] is the concatenation of (x, y) coordinates with shape
[N, H, W] each, with x indexing the 4th-D(W) of input feature map and y to
indexng the 3rd-D(H), finally results is the bilinear interpolation value
of 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points.
wn ------- y_n ------- en
| | |
| d_n |
| | |
x_w --d_w-- grid--d_e-- x_e
| | |
| d_s |
| | |
ws ------- y_s ------- wn
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_s + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-east point value
es = X[:, :, y_s, x_w] // north-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
)DOC");
} }
}; };
...@@ -91,7 +134,14 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { ...@@ -91,7 +134,14 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
//TO DO auto input_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
}
if (ctx->HasOutput(framework::GradVarName("Grid"))) {
ctx->SetOutputDim(framework::GradVarName("Grid"), grid_dims);
}
} }
protected: protected:
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -33,7 +33,7 @@ using Array4 = Eigen::DSizes<int64_t, 4>; ...@@ -33,7 +33,7 @@ using Array4 = Eigen::DSizes<int64_t, 4>;
template <typename T> template <typename T>
inline bool isInBound(T x, T y, T x_max, T y_max) { static inline bool isInBound(T x, T y, T x_max, T y_max) {
if (x < 0 || x > x_max || y < 0 || y > y_max) { if (x < 0 || x > x_max || y < 0 || y > y_max) {
return false; return false;
} }
...@@ -41,10 +41,10 @@ inline bool isInBound(T x, T y, T x_max, T y_max) { ...@@ -41,10 +41,10 @@ inline bool isInBound(T x, T y, T x_max, T y_max) {
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void CalcGridLocations(const framework::ExecutionContext& ctx, const Tensor& grid, static void CalcGridLocations(const DeviceContext& ctx, const Tensor& grid,
Tensor* x_w, Tensor* x_e, Tensor* y_n, Tensor* y_s, Tensor* x_w, Tensor* x_e, Tensor* y_n, Tensor* y_s,
Tensor* d_w, Tensor* d_e, Tensor* d_n, Tensor* d_s) { Tensor* d_w, Tensor* d_e, Tensor* d_n, Tensor* d_s) {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device(); auto& place = *ctx.eigen_device();
const int n = grid.dims()[0]; const int n = grid.dims()[0];
const int h = grid.dims()[1]; const int h = grid.dims()[1];
const int w = grid.dims()[2]; const int w = grid.dims()[2];
...@@ -71,6 +71,7 @@ void CalcGridLocations(const framework::ExecutionContext& ctx, const Tensor& gri ...@@ -71,6 +71,7 @@ void CalcGridLocations(const framework::ExecutionContext& ctx, const Tensor& gri
grid_x_t.device(place) = 0.5 * ((grid_x_t + ones_t) * x_max); grid_x_t.device(place) = 0.5 * ((grid_x_t + ones_t) * x_max);
grid_y_t.device(place) = 0.5 * ((grid_y_t + ones_t) * y_max); grid_y_t.device(place) = 0.5 * ((grid_y_t + ones_t) * y_max);
// calculate coords of 4 corner points
x_w->mutable_data<T>({n, h, w}, ctx.GetPlace()); x_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
x_e->mutable_data<T>({n, h, w}, ctx.GetPlace()); x_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
y_n->mutable_data<T>({n, h, w}, ctx.GetPlace()); y_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
...@@ -84,6 +85,7 @@ void CalcGridLocations(const framework::ExecutionContext& ctx, const Tensor& gri ...@@ -84,6 +85,7 @@ void CalcGridLocations(const framework::ExecutionContext& ctx, const Tensor& gri
y_n_t.device(place) = grid_y_t.floor(); y_n_t.device(place) = grid_y_t.floor();
y_s_t.device(place) = y_n_t + ones_t; y_s_t.device(place) = y_n_t + ones_t;
// calculate distances to 4 sides
d_w->mutable_data<T>({n, h, w}, ctx.GetPlace()); d_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_e->mutable_data<T>({n, h, w}, ctx.GetPlace()); d_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_n->mutable_data<T>({n, h, w}, ctx.GetPlace()); d_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
...@@ -99,7 +101,7 @@ void CalcGridLocations(const framework::ExecutionContext& ctx, const Tensor& gri ...@@ -99,7 +101,7 @@ void CalcGridLocations(const framework::ExecutionContext& ctx, const Tensor& gri
} }
template <typename T> template <typename T>
void GetGridPointValue(const Tensor& input, Tensor* output, static void GetGridPointValue(const Tensor& input, Tensor* output,
const Tensor& x, const Tensor& y) { const Tensor& x, const Tensor& y) {
const int n = input.dims()[0]; const int n = input.dims()[0];
const int c = input.dims()[1]; const int c = input.dims()[1];
...@@ -124,7 +126,7 @@ void GetGridPointValue(const Tensor& input, Tensor* output, ...@@ -124,7 +126,7 @@ void GetGridPointValue(const Tensor& input, Tensor* output,
} }
template <typename T> template <typename T>
void GatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input_grad, static void GatherOutputGradToInputGrad(const Tensor& output_grad, Tensor* input_grad,
const Tensor& x, const Tensor& y, const Tensor& x, const Tensor& y,
const Tensor& d1, const Tensor& d2) { const Tensor& d1, const Tensor& d2) {
const int n = output_grad.dims()[0]; const int n = output_grad.dims()[0];
...@@ -170,9 +172,10 @@ class GridSampleOpKernel : public framework::OpKernel<T> { ...@@ -170,9 +172,10 @@ class GridSampleOpKernel : public framework::OpKernel<T> {
// calc locations and distances of 4 corner points // calc locations and distances of 4 corner points
Tensor x_w, x_e, y_n, y_s; Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s; Tensor d_w, d_e, d_n, d_s;
CalcGridLocations<DeviceContext, T>(ctx, *grid, CalcGridLocations<DeviceContext, T>(ctx.template device_context<DeviceContext>(),
&x_w, &x_e, &y_n, &y_s, *grid,
&d_w, &d_e, &d_n, &d_s); &x_w, &x_e, &y_n, &y_s,
&d_w, &d_e, &d_n, &d_s);
auto* output = ctx.Output<Tensor>("Output"); auto* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>({n, c, h, w}, ctx.GetPlace()); output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
...@@ -239,9 +242,10 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> { ...@@ -239,9 +242,10 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> {
Tensor x_w, x_e, y_n, y_s; Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s; Tensor d_w, d_e, d_n, d_s;
CalcGridLocations<DeviceContext, T>(ctx, *grid, CalcGridLocations<DeviceContext, T>(ctx.template device_context<DeviceContext>(),
&x_w, &x_e, &y_n, &y_s, *grid,
&d_w, &d_e, &d_n, &d_s); &x_w, &x_e, &y_n, &y_s,
&d_w, &d_e, &d_n, &d_s);
// gather output grad value to input grad by corner point coords and weight // gather output grad value to input grad by corner point coords and weight
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_n, d_e, d_s); GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_n, d_e, d_s);
......
...@@ -7584,17 +7584,59 @@ def hash(input, hash_size, num_hash=1, name=None): ...@@ -7584,17 +7584,59 @@ def hash(input, hash_size, num_hash=1, name=None):
@templatedoc() @templatedoc()
def grid_sampler(x, grid): def grid_sampler(x, grid, name=None):
""" """
It sample data from input x by the given grid, insert data of each It sample input X by grid gennerate by AffineGridOp. The grid of shape
point by bilinear interp. [N, H, W, 2] is the concatenation of (x, y) coordinates with shape
[N, H, W] each, with x indexing the 4th-D(W) of input feature map and y to
indexng the 3rd-D(H), finally results is the bilinear interpolation value
of 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale to [0, H-1/W-1].
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Indices input data X with grid (x, y) in each [H, W] area, and bilinear
interpolate point value by 4 nearest points.
wn ------- y_n ------- en
| | |
| d_n |
| | |
x_w --d_w-- grid--d_e-- x_e
| | |
| d_s |
| | |
ws ------- y_s ------- wn
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_s + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-east point value
es = X[:, :, y_s, x_w] // north-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
Args: Args:
x(Variable): Input data of shape [N, H, W, C] x(Variable): Input data of shape [N, C, H, W].
grid(Variable): Input grid tensor of shape [N, H, W, 2] grid(Variable): Input grid tensor of shape [N, H, W, 2].
name (str, default None): The name of this layer.
Returns: Returns:
out(Variable): Output data indices by grid from x of shape [N, H, W, C] out(Variable): Output data indices by grid from x of shape [N, C, H, W].
""" """
helper = LayerHelper("grid_sampler", **locals()) helper = LayerHelper("grid_sampler", **locals())
...@@ -7606,13 +7648,11 @@ def grid_sampler(x, grid): ...@@ -7606,13 +7648,11 @@ def grid_sampler(x, grid):
out = helper.create_tmp_variable(x.dtype) out = helper.create_tmp_variable(x.dtype)
ipts = {'X': x, 'Grid': grid} ipts = {'X': x, 'Grid': grid}
attrs = {}
helper.apppend_op( helper.apppend_op(
type='grid_sampler', type='grid_sampler',
inputs=ipts, inputs=ipts,
outputs={'Output', out}, outputs={'Output', out})
attrs = None if len(attrs) == 0 else attrs)
return 0 return out
...@@ -35,7 +35,6 @@ def AffineGrid(theta, size): ...@@ -35,7 +35,6 @@ def AffineGrid(theta, size):
for i in range(len(theta)): for i in range(len(theta)):
ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i]) ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i])
# print ret.reshape([n, h * w, 2]).astype("float32")
return ret.reshape([n, h, w, 2]).astype("float32") return ret.reshape([n, h, w, 2]).astype("float32")
def getGridPointValue(data, x, y): def getGridPointValue(data, x, y):
...@@ -104,13 +103,12 @@ class TestGridSamplerOp(OpTest): ...@@ -104,13 +103,12 @@ class TestGridSamplerOp(OpTest):
self.inputs = {'X': x, 'Grid': grid} self.inputs = {'X': x, 'Grid': grid}
self.attrs = {'use_cudnn': True} self.attrs = {'use_cudnn': True}
self.outputs = {'Output': GridSampler(x, grid)} self.outputs = {'Output': GridSampler(x, grid)}
# print self.outputs
def test_check_output(self): def test_check_output(self):
self.check_output(atol=1e-3) self.check_output(atol=1e-3)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.6) self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61)
def initTestCase(self): def initTestCase(self):
self.x_shape = (2, 5, 7, 3) self.x_shape = (2, 5, 7, 3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册