未验证 提交 f2c21ff9 编写于 作者: W whs 提交者: GitHub

[cherry-pick 2.0API]Add checker in grid_sample_grad op (#27127)

上级 fc8bc1ba
......@@ -115,7 +115,7 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::string>(
"padding_mode",
"(bool, default true) The padding method used when source"
"index is out of input images. It can be 'zeros', 'reflect' and "
"index is out of input images. It can be 'zeros', 'reflection' and "
"'border'.")
.SetDefault("zeros");
......@@ -174,6 +174,10 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "grid_sampler");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Grid")), "Output",
framework::GradVarName("Grid"), "grid_sampler");
auto input_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
if (ctx->HasOutput(framework::GradVarName("X"))) {
......
......@@ -268,7 +268,7 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> {
Mode mode;
if (padding_mode_s == "border") {
padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflect") {
} else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect;
} else {
padding_mode = PaddingMode::zeros;
......@@ -432,7 +432,7 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
Mode mode;
if (padding_mode_s == "border") {
padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflect") {
} else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect;
} else {
padding_mode = PaddingMode::zeros;
......
......@@ -76,7 +76,7 @@ static inline void clip(const platform::CPUDeviceContext& ctx,
if (padding_mode == "border") {
grid_slice_t.device(place) = grid_slice_t.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(max_val));
} else if (padding_mode == "reflect") {
} else if (padding_mode == "reflection") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto grid_abs = grid_slice_t.abs();
......@@ -117,7 +117,7 @@ static inline void clipWithMask(const platform::CPUDeviceContext& ctx,
auto in_bound = (res == grid_slice_t);
grid_scale_t.device(place) = grid_scale_t * in_bound.template cast<T>();
grid_slice_t.device(place) = res;
} else if (padding_mode == "reflect") {
} else if (padding_mode == "reflection") {
if (align_corners) {
auto double_range = static_cast<T>(max_val * 2);
auto is_neg = (grid_slice_t < static_cast<T>(0));
......
......@@ -100,7 +100,7 @@ def add_cases(suite):
GridSampleTestCase(
methodName='runTest',
mode='bilinear',
padding_mode='reflect',
padding_mode='reflection',
align_corners=True))
suite.addTest(
GridSampleTestCase(
......
......@@ -73,7 +73,7 @@ def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode):
if padding_mode == "border":
grid_slice = clip(grid_slice, 0, max_val)
elif padding_mode == "reflect":
elif padding_mode == "reflection":
double_range = 2 * max_val if align_corners else (max_val + 1) * 2
grid_abs = np.abs(grid_slice) if align_corners else np.abs(grid_slice +
0.5)
......@@ -211,7 +211,7 @@ class Case2(TestGridSamplerOp):
self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3)
self.align_corners = False
self.padding_mode = "reflect"
self.padding_mode = "reflection"
self.mode = "bilinear"
......@@ -221,7 +221,7 @@ class Case3(TestGridSamplerOp):
self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3)
self.align_corners = True
self.padding_mode = "reflect"
self.padding_mode = "reflection"
self.mode = "bilinear"
......@@ -231,7 +231,7 @@ class Case4(TestGridSamplerOp):
self.grid_shape = (2, 8, 9, 2)
self.theta_shape = (2, 2, 3)
self.align_corners = False
self.padding_mode = "reflect"
self.padding_mode = "reflection"
self.mode = "nearest"
self.numeric_grad_delta = 0.0001
......
......@@ -249,7 +249,7 @@ def grid_sample(x,
mode(str, optional): The interpolation method which can be 'bilinear' or 'nearest'.
Default: 'bilinear'.
padding_mode(str, optional) The padding method used when source index
is out of input images. It can be 'zeros', 'reflect' and 'border'.
is out of input images. It can be 'zeros', 'reflection' and 'border'.
Default: zeros.
align_corners(bool, optional): If `align_corners` is true, it will projects
-1 and 1 to the centers of the corner pixels. Otherwise, it will
......@@ -312,7 +312,7 @@ def grid_sample(x,
if not isinstance(grid, Variable):
raise ValueError("The grid should be a Variable")
_modes = ['bilinear', 'nearest']
_padding_modes = ['zeros', 'reflect', 'border']
_padding_modes = ['zeros', 'reflection', 'border']
if mode not in _modes:
raise ValueError(
"The mode of grid sample function should be in {}, but got: {}".
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册