未验证 提交 b1b24463 编写于 作者: C Chen Weihang 提交者: GitHub

move grid sample op infershape (#40625)

上级 827b6a0e
......@@ -15,9 +15,13 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -27,43 +31,6 @@ using Tensor = framework::Tensor;
class GridSampleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GridSampler");
OP_INOUT_CHECK(ctx->HasInput("Grid"), "Input", "Grid", "GridSampler");
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "GridSampler");
auto x_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
PADDLE_ENFORCE_EQ(x_dims.size(), 4,
platform::errors::InvalidArgument(
"Input(X) of GridSampleOp should be 4-D Tensor, but "
"received X dimension size(%d)",
x_dims.size()));
PADDLE_ENFORCE_EQ(grid_dims.size(), 4,
platform::errors::InvalidArgument(
"Input(Grid) of GridSampleOp should be 4-D Tensor, "
"but received X dimension size(%d)",
grid_dims.size()));
if (ctx->IsRuntime() || grid_dims[3] > 0) {
PADDLE_ENFORCE_EQ(
grid_dims[3], 2,
platform::errors::InvalidArgument(
"Input(Grid) dimension[3] should be 2, but received %d",
grid_dims[3]));
}
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
grid_dims[0], x_dims[0],
platform::errors::InvalidArgument(
"Input(X) and Input(Grid) dimension[0] should be equal, but "
"received X dimension[0](%d) != Grid dimension[0](%d)",
x_dims[0], grid_dims[0]));
}
ctx->SetOutputDim("Output",
{x_dims[0], x_dims[1], grid_dims[1], grid_dims[2]});
ctx->ShareLoD("X", "Output");
}
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -173,18 +140,6 @@ class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
class GridSampleOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "grid_sampler");
auto input_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
}
if (ctx->HasOutput(framework::GradVarName("Grid"))) {
ctx->SetOutputDim(framework::GradVarName("Grid"), grid_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -224,10 +179,16 @@ class GridSampleGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(grid_sampler, GridSamplerInferShapeFunctor,
PD_INFER_META(phi::GridSampleBaseInferMeta));
REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker,
ops::GridSampleGradMaker<paddle::framework::OpDesc>,
ops::GridSampleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad);
ops::GridSampleGradMaker<paddle::imperative::OpBase>,
GridSamplerInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(grid_sampler_grad, GridSamplerGradInferShapeFunctor,
PD_INFER_META(phi::GeneralBinaryGradInferMeta));
REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad,
GridSamplerGradInferShapeFunctor);
REGISTER_OP_VERSION(grid_sampler)
.AddCheckpoint(
......
......@@ -571,6 +571,48 @@ void GatherTreeMeta(const MetaTensor& ids,
out->set_dims(ids_dims);
}
void GridSampleBaseInferMeta(const MetaTensor& x,
const MetaTensor& grid,
MetaTensor* out,
MetaConfig config) {
auto x_dims = x.dims();
auto grid_dims = grid.dims();
PADDLE_ENFORCE_EQ(x_dims.size(),
4,
phi::errors::InvalidArgument(
"Input(X) of GridSampleOp should be 4-D Tensor, but "
"received X dimension size(%d)",
x_dims.size()));
PADDLE_ENFORCE_EQ(grid_dims.size(),
4,
phi::errors::InvalidArgument(
"Input(Grid) of GridSampleOp should be 4-D Tensor, "
"but received X dimension size(%d)",
grid_dims.size()));
if (config.is_runtime || grid_dims[3] > 0) {
PADDLE_ENFORCE_EQ(
grid_dims[3],
2,
phi::errors::InvalidArgument(
"Input(Grid) dimension[3] should be 2, but received %d",
grid_dims[3]));
}
if (config.is_runtime) {
PADDLE_ENFORCE_EQ(
grid_dims[0],
x_dims[0],
phi::errors::InvalidArgument(
"Input(X) and Input(Grid) dimension[0] should be equal, but "
"received X dimension[0](%d) != Grid dimension[0](%d)",
x_dims[0],
grid_dims[0]));
}
out->set_dims({x_dims[0], x_dims[1], grid_dims[1], grid_dims[2]});
out->set_dtype(x.dtype());
out->share_lod(x);
}
void HuberLossInferMeta(const MetaTensor& input,
const MetaTensor& label,
float delta,
......
......@@ -103,6 +103,11 @@ void GatherTreeMeta(const MetaTensor& ids,
const MetaTensor& parents,
MetaTensor* out);
void GridSampleBaseInferMeta(const MetaTensor& x,
const MetaTensor& grid,
MetaTensor* out,
MetaConfig config = MetaConfig());
void HuberLossInferMeta(const MetaTensor& input_meta,
const MetaTensor& label_meta,
float delta,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册