未验证 提交 90280542 编写于 作者: W Wang Xin 提交者: GitHub

add autogen code support for affine_grid op (#52560)

* add autogen code support for affine_grid op

* update op_compat.yaml for affine_grid

* update op_compat.yaml for affine_grid

* fix AffineGridGradInferMeta

* fix CI error

* update AffineGridInferMeta
上级 ec008a71
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class AffineGridOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Theta"),
true,
platform::errors::NotFound(
"The input 'Theta' of AffineGridOp is not found."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"),
true,
platform::errors::NotFound(
"The output 'Output' of AffineGridOp is not found."));
auto theta_dims = ctx->GetInputDim("Theta");
PADDLE_ENFORCE_EQ(
theta_dims.size(),
3,
platform::errors::InvalidArgument(
"The input Theta's dimensions size should be 3. But received "
"Theta's demensions size=[%d], Theta's dimensions=[%s].",
theta_dims.size(),
theta_dims));
auto output_shape = ctx->Attrs().Get<std::vector<int>>("output_shape");
if (output_shape.size() == 0) {
PADDLE_ENFORCE_EQ(
ctx->HasInput("OutputShape"),
true,
platform::errors::NotFound(
"The input 'OutputShape' of AffineGridOp should not be null if "
"'output_shape' is not configured."));
auto output_shape_dims = ctx->GetInputDim("OutputShape");
PADDLE_ENFORCE_EQ(
output_shape_dims.size(),
1,
platform::errors::InvalidArgument(
"The dimesions size of input OutputShape in AffineGridOp should "
"be 1. But received OutputShape's dimesions size=[%d], "
"OutputShape's dimesions=[%s]",
output_shape_dims.size(),
output_shape_dims));
} else {
PADDLE_ENFORCE_GE(output_shape.size(),
4,
platform::errors::InvalidArgument(
"The size of attribute 'output_shape' in "
"AffineGridOp should be >= "
"4. But received output_shape's size=[%d].",
output_shape.size()));
PADDLE_ENFORCE_LE(output_shape.size(),
5,
platform::errors::InvalidArgument(
"The size of attribute 'output_shape' in "
"AffineGridOp should be <= "
"5. But received output_shape's size=[%d].",
output_shape.size()));
}
PADDLE_ENFORCE_GE(theta_dims[1],
2,
platform::errors::InvalidArgument(
"The second dimesion of input 'theta' in "
"AffineGridOp should be >= 2. "
"But received second dimesion=[%d], dimesions=[%s]",
theta_dims[1],
theta_dims));
PADDLE_ENFORCE_LE(theta_dims[1],
3,
platform::errors::InvalidArgument(
"The second dimesion of input 'theta' in "
"AffineGridOp should be <= 3. "
"But received second dimesion=[%d], dimesions=[%s]",
theta_dims[1],
theta_dims));
PADDLE_ENFORCE_GE(theta_dims[2],
3,
platform::errors::InvalidArgument(
"The third dimesion of input 'theta' in AffineGridOp "
"should be >= 3. "
"But received third dimesion=[%d], dimesions=[%s]",
theta_dims[2],
theta_dims));
PADDLE_ENFORCE_LE(theta_dims[2],
4,
platform::errors::InvalidArgument(
"The third dimesion of input 'theta' in AffineGridOp "
"should be <= 4. "
"But received third dimesion=[%d], dimesions=[%s]",
theta_dims[2],
theta_dims));
if (output_shape.size() == 4) {
// N * H * W * 2
ctx->SetOutputDim("Output", phi::make_ddim({theta_dims[0], -1, -1, 2}));
} else {
// N * D * H * W * 3
ctx->SetOutputDim("Output",
phi::make_ddim({theta_dims[0], -1, -1, -1, 3}));
}
ctx->ShareLoD("Theta", "Output");
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta");
return phi::KernelKey(data_type, ctx.GetPlace());
}
};
class AffineGridOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"Theta",
"(Tensor) A batch of affine transform parameters with shape [N, 2, 3]. "
"It is used to transform coordinate (x_0, y_0) to coordinate (x_1, "
"y_1).");
AddInput("OutputShape",
"(Tensor) The shape of target image with format [N, C, H, W].")
.AsDispensable();
AddOutput("Output", "(Tensor) Output Tensor with shape [N, H, W, 2].");
AddAttr<bool>("align_corners",
"(bool, default false) Whether to align the corners of input"
"and output.")
.SetDefault(true);
AddAttr<std::vector<int>>(
"output_shape",
"The target output image shape with format [N, C, H, W].")
.SetDefault(std::vector<int>());
AddComment(R"DOC(
It generates a grid of (x,y) coordinates using the parameters of the
affine transformation that correspond to a set of points where the input
feature map should be sampled to produce the transformed output feature map.
Given:
Theta = [[[x_11, x_12, x_13]
[x_14, x_15, x_16]]
[[x_21, x_22, x_23]
[x_24, x_25, x_26]]]
OutputShape = [2, 3, 5, 5]
Step 1:
Generate relative coordinates according to OutputShape.
The values of relative coordinates are in the interval between -1 and 1.
The shape of the relative coordinates is [2, H, W] as below:
C = [[[-1. -1. -1. -1. -1. ]
[-0.5 -0.5 -0.5 -0.5 -0.5]
[ 0. 0. 0. 0. 0. ]
[ 0.5 0.5 0.5 0.5 0.5]
[ 1. 1. 1. 1. 1. ]]
[[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]
[-1. -0.5 0. 0.5 1. ]]]
C[0] is the coordinates in height axis and C[1] is the coordinates in
width axis.
Step2:
Tanspose and reshape C to shape [H * W, 2] and append ones to last
dimension. The we get:
C_ = [[-1. -1. 1. ]
[-0.5 -1. 1. ]
[ 0. -1. 1. ]
[ 0.5 -1. 1. ]
[ 1. -1. 1. ]
[-1. -0.5 1. ]
[-0.5 -0.5 1. ]
[ 0. -0.5 1. ]
[ 0.5 -0.5 1. ]
[ 1. -0.5 1. ]
[-1. 0. 1. ]
[-0.5 0. 1. ]
[ 0. 0. 1. ]
[ 0.5 0. 1. ]
[ 1. 0. 1. ]
[-1. 0.5 1. ]
[-0.5 0.5 1. ]
[ 0. 0.5 1. ]
[ 0.5 0.5 1. ]
[ 1. 0.5 1. ]
[-1. 1. 1. ]
[-0.5 1. 1. ]
[ 0. 1. 1. ]
[ 0.5 1. 1. ]
[ 1. 1. 1. ]]
Step3:
Compute output by equation $$Output[i] = C_ * Theta[i]^T$$
)DOC");
}
};
class AffineGridOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->HasOutput(framework::GradVarName("Theta"))) {
auto output_dims = ctx->GetInputDim(framework::GradVarName("Output"));
if (output_dims.size() == 4) {
ctx->SetOutputDim(framework::GradVarName("Theta"),
{output_dims[0], 2, 3});
} else {
ctx->SetOutputDim(framework::GradVarName("Theta"),
{output_dims[0], 3, 4});
}
}
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Output"));
return phi::KernelKey(data_type, ctx.GetPlace());
}
};
template <typename T>
class AffineGridGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("affine_grid_grad");
op->SetInput("OutputShape", this->Input("OutputShape"));
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("Theta"), this->InputGrad("Theta"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(affine_grid,
ops::AffineGridOp,
ops::AffineGridOpMaker,
ops::AffineGridGradMaker<paddle::framework::OpDesc>,
ops::AffineGridGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(affine_grid_grad, ops::AffineGridOpGrad);
REGISTER_OP_VERSION(affine_grid)
.AddCheckpoint(
R"ROC(
Compatible upgrade of affine_grid, add a new attribute [align_corners])ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"align_corners",
"Whether to align the corners of input and output.",
true));
......@@ -33,6 +33,17 @@
kernel :
func : addmm_grad
- backward_op : affine_grid_grad
forward : affine_grid (Tensor input, IntArray output_shape={}, bool align_corners=true) -> Tensor(output)
args : (Tensor input, Tensor output_grad, IntArray output_shape, bool align_corners=true)
output : Tensor(input_grad)
infer_meta :
func : AffineGridGradInferMeta
param : [output_grad, output_shape, align_corners]
kernel :
func : affine_grid_grad
param : [output_grad, output_shape, align_corners]
- backward_op : angle_grad
forward : angle (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
......@@ -58,18 +58,6 @@
func : add_triple_grad
inplace : (grad_grad_out_grad -> grad_grad_x_grad)
- backward_op : affine_grid_grad
forward : affine_grid (Tensor input, IntArray outputShape, bool align_corners=true) -> Tensor(output)
args : (Tensor input, Tensor output_grad, IntArray outputShape, bool align_corners=true)
output : Tensor(input_grad)
infer_meta :
func : AffineGridGradInferMeta
param : [output_grad, outputShape, align_corners]
kernel :
func : affine_grid_grad
param : [output_grad, outputShape, align_corners]
no_need_buffer : input
- backward_op : amax_grad
forward: amax (Tensor x, int64_t[] axis={}, bool keepdim=false) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis={}, bool keepdim=false, bool reduce_all=false)
......
......@@ -83,18 +83,6 @@
invoke : add_n_impl(inputs)
backward : add_n_grad
- op : affine_grid
args : (Tensor input, IntArray outputShape, bool align_corners=true)
output : Tensor
infer_meta :
func : AffineGridInferMeta
param : [input, outputShape, align_corners]
kernel :
func : affine_grid
param : [input, outputShape, align_corners]
data_type : input
backward : affine_grid_grad
- op : all
args : (Tensor x, int64_t[] axis={}, bool keepdim=false)
output : Tensor(out)
......
......@@ -63,6 +63,14 @@
- op : affine_grid
backward : affine_grid_grad
inputs :
input : Theta
outputs :
out : Output
int_array:
output_shape :
data_type : int
tensor_name : OutputShape
extra :
attrs : [bool use_cudnn = true]
......
- op : affine_grid
version :
- checkpoint : Compatible upgrade of affine_grid, add a new attribute [align_corners].
action :
- add_attr : align_corners
comment : Whether to align the corners of input and output.
default : "true"
- op : allclose
version :
- checkpoint : Upgrade allclose, add two new inputs [Rtol] and [Atol].
......
......@@ -42,6 +42,18 @@
data_type : x
backward : addmm_grad
- op : affine_grid
args : (Tensor input, IntArray output_shape={}, bool align_corners=true)
output : Tensor
infer_meta :
func : AffineGridInferMeta
param : [input, output_shape, align_corners]
kernel :
func : affine_grid
param : [input, output_shape, align_corners]
data_type : input
backward : affine_grid_grad
- op : allclose
args : (Tensor x, Tensor y, Scalar rtol="1e-5", Scalar atol="1e-8", bool equal_nan=false)
output : Tensor(out)
......
......@@ -25,7 +25,11 @@ void AffineGridGradInferMeta(const MetaTensor& output_grad,
MetaTensor* input_grad) {
if (input_grad) {
auto output_dims = output_grad.dims();
input_grad->set_dims(phi::make_ddim({output_dims[0], 2, 3}));
if (output_dims.size() == 4) {
input_grad->set_dims(phi::make_ddim({output_dims[0], 2, 3}));
} else {
input_grad->set_dims(phi::make_ddim({output_dims[0], 3, 4}));
}
}
}
......
......@@ -49,31 +49,33 @@ void AffineGridInferMeta(const MetaTensor& input,
bool align_corners,
MetaTensor* output) {
auto theta_dims = input.dims();
PADDLE_ENFORCE_EQ(
theta_dims.size(),
3,
phi::errors::InvalidArgument(
"The input Theta's dimensions size should be 3. But received "
"Theta's demensions size=[%d], Theta's dimensions=[%s].",
theta_dims.size(),
theta_dims));
PADDLE_ENFORCE_GE(
outputShape.GetData().size(),
4,
phi::errors::InvalidArgument(
"The size of attribute 'output_shape' in AffineGridOp should be >= "
"4. But received output_shape's size=[%d].",
outputShape.GetData().size()));
bool is_from_tensor = outputShape.FromTensor();
if (!is_from_tensor) {
PADDLE_ENFORCE_EQ(
theta_dims.size(),
3,
phi::errors::InvalidArgument(
"The input Theta's dimensions size should be 3. But received "
"Theta's demensions size=[%d], Theta's dimensions=[%s].",
theta_dims.size(),
theta_dims));
PADDLE_ENFORCE_LE(
outputShape.GetData().size(),
5,
phi::errors::InvalidArgument(
"The size of attribute 'output_shape' in AffineGridOp should be <= "
"5. But received output_shape's size=[%d].",
outputShape.GetData().size()));
PADDLE_ENFORCE_GE(
outputShape.GetData().size(),
4,
phi::errors::InvalidArgument(
"The size of attribute 'output_shape' in AffineGridOp should be >= "
"4. But received output_shape's size=[%d].",
outputShape.GetData().size()));
PADDLE_ENFORCE_LE(
outputShape.GetData().size(),
5,
phi::errors::InvalidArgument(
"The size of attribute 'output_shape' in AffineGridOp should be <= "
"5. But received output_shape's size=[%d].",
outputShape.GetData().size()));
}
PADDLE_ENFORCE_GE(theta_dims[1],
2,
phi::errors::InvalidArgument(
......@@ -109,7 +111,7 @@ void AffineGridInferMeta(const MetaTensor& input,
"But received third dimesion=[%d], dimesions=[%s]",
theta_dims[2],
theta_dims));
if (outputShape.GetData().size() == 4) {
if (outputShape.GetData().size() == 4 && !is_from_tensor) {
// N * H * W * 2
output->set_dims(phi::make_ddim({theta_dims[0], -1, -1, 2}));
} else {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature AffineGridOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("OutputShape")) {
return KernelSignature(
"affine_grid", {"Theta"}, {"OutputShape", "align_corners"}, {"Output"});
} else {
return KernelSignature("affine_grid",
{"Theta"},
{"output_shape", "align_corners"},
{"Output"});
}
}
KernelSignature AffineGridGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.HasInput("OutputShape")) {
return KernelSignature("affine_grid_grad",
{"Output@GRAD"},
{"OutputShape", "align_corners"},
{"Theta@GRAD"});
} else {
return KernelSignature("affine_grid_grad",
{"Output@GRAD"},
{"output_shape", "align_corners"},
{"Theta@GRAD"});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(affine_grid, phi::AffineGridOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(affine_grid_grad,
phi::AffineGridGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册