未验证 提交 eb019760 编写于 作者: W whs 提交者: GitHub

[2.0 API]Add checker in grid_sample_grad op (#27126)

上级 a28ae86e
...@@ -115,7 +115,7 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -115,7 +115,7 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::string>( AddAttr<std::string>(
"padding_mode", "padding_mode",
"(bool, default true) The padding method used when source" "(bool, default true) The padding method used when source"
"index is out of input images. It can be 'zeros', 'reflect' and " "index is out of input images. It can be 'zeros', 'reflection' and "
"'border'.") "'border'.")
.SetDefault("zeros"); .SetDefault("zeros");
...@@ -174,6 +174,10 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { ...@@ -174,6 +174,10 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "grid_sampler");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Grid")), "Output",
framework::GradVarName("Grid"), "grid_sampler");
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid"); auto grid_dims = ctx->GetInputDim("Grid");
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
......
...@@ -268,7 +268,7 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> { ...@@ -268,7 +268,7 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> {
Mode mode; Mode mode;
if (padding_mode_s == "border") { if (padding_mode_s == "border") {
padding_mode = PaddingMode::border; padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflect") { } else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect; padding_mode = PaddingMode::reflect;
} else { } else {
padding_mode = PaddingMode::zeros; padding_mode = PaddingMode::zeros;
...@@ -432,7 +432,7 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -432,7 +432,7 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
Mode mode; Mode mode;
if (padding_mode_s == "border") { if (padding_mode_s == "border") {
padding_mode = PaddingMode::border; padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflect") { } else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect; padding_mode = PaddingMode::reflect;
} else { } else {
padding_mode = PaddingMode::zeros; padding_mode = PaddingMode::zeros;
......
...@@ -76,7 +76,7 @@ static inline void clip(const platform::CPUDeviceContext& ctx, ...@@ -76,7 +76,7 @@ static inline void clip(const platform::CPUDeviceContext& ctx,
if (padding_mode == "border") { if (padding_mode == "border") {
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0)) grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val)); .cwiseMin(static_cast<T>(max_val));
} else if (padding_mode == "reflect") { } else if (padding_mode == "reflection") {
if (align_corners) { if (align_corners) {
auto double_range = static_cast<T>(max_val * 2); auto double_range = static_cast<T>(max_val * 2);
auto grid_abs = grid_slice_t.abs(); auto grid_abs = grid_slice_t.abs();
...@@ -117,7 +117,7 @@ static inline void clipWithMask(const platform::CPUDeviceContext& ctx, ...@@ -117,7 +117,7 @@ static inline void clipWithMask(const platform::CPUDeviceContext& ctx,
auto in_bound = (res == grid_slice_t); auto in_bound = (res == grid_slice_t);
grid_scale_t.device(place) = grid_scale_t * in_bound.template cast<T>(); grid_scale_t.device(place) = grid_scale_t * in_bound.template cast<T>();
grid_slice_t.device(place) = res; grid_slice_t.device(place) = res;
} else if (padding_mode == "reflect") { } else if (padding_mode == "reflection") {
if (align_corners) { if (align_corners) {
auto double_range = static_cast<T>(max_val * 2); auto double_range = static_cast<T>(max_val * 2);
auto is_neg = (grid_slice_t < static_cast<T>(0)); auto is_neg = (grid_slice_t < static_cast<T>(0));
......
...@@ -100,7 +100,7 @@ def add_cases(suite): ...@@ -100,7 +100,7 @@ def add_cases(suite):
GridSampleTestCase( GridSampleTestCase(
methodName='runTest', methodName='runTest',
mode='bilinear', mode='bilinear',
padding_mode='reflect', padding_mode='reflection',
align_corners=True)) align_corners=True))
suite.addTest( suite.addTest(
GridSampleTestCase( GridSampleTestCase(
......
...@@ -73,7 +73,7 @@ def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode): ...@@ -73,7 +73,7 @@ def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode):
if padding_mode == "border": if padding_mode == "border":
grid_slice = clip(grid_slice, 0, max_val) grid_slice = clip(grid_slice, 0, max_val)
elif padding_mode == "reflect": elif padding_mode == "reflection":
double_range = 2 * max_val if align_corners else (max_val + 1) * 2 double_range = 2 * max_val if align_corners else (max_val + 1) * 2
grid_abs = np.abs(grid_slice) if align_corners else np.abs(grid_slice + grid_abs = np.abs(grid_slice) if align_corners else np.abs(grid_slice +
0.5) 0.5)
...@@ -211,7 +211,7 @@ class Case2(TestGridSamplerOp): ...@@ -211,7 +211,7 @@ class Case2(TestGridSamplerOp):
self.grid_shape = (2, 8, 9, 2) self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3) self.theta_shape = (2, 2, 3)
self.align_corners = False self.align_corners = False
self.padding_mode = "reflect" self.padding_mode = "reflection"
self.mode = "bilinear" self.mode = "bilinear"
...@@ -221,7 +221,7 @@ class Case3(TestGridSamplerOp): ...@@ -221,7 +221,7 @@ class Case3(TestGridSamplerOp):
self.grid_shape = (2, 8, 9, 2) self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3) self.theta_shape = (2, 2, 3)
self.align_corners = True self.align_corners = True
self.padding_mode = "reflect" self.padding_mode = "reflection"
self.mode = "bilinear" self.mode = "bilinear"
...@@ -231,7 +231,7 @@ class Case4(TestGridSamplerOp): ...@@ -231,7 +231,7 @@ class Case4(TestGridSamplerOp):
self.grid_shape = (2, 8, 9, 2) self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3) self.theta_shape = (2, 2, 3)
self.align_corners = False self.align_corners = False
self.padding_mode = "reflect" self.padding_mode = "reflection"
self.mode = "nearest" self.mode = "nearest"
self.numeric_grad_delta = 0.0001 self.numeric_grad_delta = 0.0001
......
...@@ -249,7 +249,7 @@ def grid_sample(x, ...@@ -249,7 +249,7 @@ def grid_sample(x,
mode(str, optional): The interpolation method which can be 'bilinear' or 'nearest'. mode(str, optional): The interpolation method which can be 'bilinear' or 'nearest'.
Default: 'bilinear'. Default: 'bilinear'.
padding_mode(str, optional) The padding method used when source index padding_mode(str, optional) The padding method used when source index
is out of input images. It can be 'zeros', 'reflect' and 'border'. is out of input images. It can be 'zeros', 'reflection' and 'border'.
Default: zeros. Default: zeros.
align_corners(bool, optional): If `align_corners` is true, it will projects align_corners(bool, optional): If `align_corners` is true, it will projects
-1 and 1 to the centers of the corner pixels. Otherwise, it will -1 and 1 to the centers of the corner pixels. Otherwise, it will
...@@ -312,7 +312,7 @@ def grid_sample(x, ...@@ -312,7 +312,7 @@ def grid_sample(x,
if not isinstance(grid, Variable): if not isinstance(grid, Variable):
raise ValueError("The grid should be a Variable") raise ValueError("The grid should be a Variable")
_modes = ['bilinear', 'nearest'] _modes = ['bilinear', 'nearest']
_padding_modes = ['zeros', 'reflect', 'border'] _padding_modes = ['zeros', 'reflection', 'border']
if mode not in _modes: if mode not in _modes:
raise ValueError( raise ValueError(
"The mode of grid sample function should be in {}, but got: {}". "The mode of grid sample function should be in {}, but got: {}".
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册