diff --git a/paddle/fluid/operators/erfinv_op.cc b/paddle/fluid/operators/erfinv_op.cc index f9489a7bd09470308bcc2c7dcba49d34a4bfc291..3d409b4c4f6772bc7b234208e78c5088eeb2fc00 100644 --- a/paddle/fluid/operators/erfinv_op.cc +++ b/paddle/fluid/operators/erfinv_op.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -20,14 +23,6 @@ namespace operators { class ErfinvOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "erfinv"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "erfinv"); - - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - ctx->ShareLoD("X", /*->*/ "Out"); - } }; class ErfinvOpMaker : public framework::OpProtoAndCheckerMaker { @@ -78,10 +73,13 @@ DECLARE_INPLACE_OP_INFERER(ErfinvInplaceInferer, {"X", "Out"}); } // namespace operators } // namespace paddle +DELCARE_INFER_SHAPE_FUNCTOR(erfinv, ErfinvInferShapeFunctor, + PT_INFER_META(phi::UnchangedInferMeta)); + REGISTER_OPERATOR( erfinv, paddle::operators::ErfinvOp, paddle::operators::ErfinvOpMaker, paddle::operators::ErfinvGradMaker, paddle::operators::ErfinvGradMaker, - paddle::operators::ErfinvInplaceInferer); + paddle::operators::ErfinvInplaceInferer, ErfinvInferShapeFunctor); REGISTER_OPERATOR(erfinv_grad, paddle::operators::ErfinvGradOp); diff --git a/paddle/fluid/operators/pixel_shuffle_op.cc b/paddle/fluid/operators/pixel_shuffle_op.cc index f5c53ab320db787899b0fedad6b798f0db758ec8..2a127d9ad1db0c1e169fdd1e20a1568b99d228a0 100644 --- a/paddle/fluid/operators/pixel_shuffle_op.cc +++ b/paddle/fluid/operators/pixel_shuffle_op.cc @@ -10,8 +10,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -19,56 +22,6 @@ namespace operators { class PixelShuffleOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::NotFound( - "Input(X) of PixelShuffleOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::NotFound( - "Output(Out) of PixelShuffleOp should not be null.")); - - auto input_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ(input_dims.size(), 4, - platform::errors::InvalidArgument( - "Input should be a 4-D tensor of format [N, C, H, W] " - "or [N, H, W, C], but got %u.", - input_dims.size())); - - auto upscale_factor = ctx->Attrs().Get("upscale_factor"); - - const std::string data_format = - ctx->Attrs().Get("data_format"); - const bool channel_last = (data_format == "NHWC"); - - if (!channel_last) { - PADDLE_ENFORCE_EQ( - input_dims[1] % (upscale_factor * upscale_factor), 0, - platform::errors::InvalidArgument( - "The square of upscale_factor[%u] should divide the " - "number of channel[%u]", - upscale_factor * upscale_factor, input_dims[1])); - } else { - PADDLE_ENFORCE_EQ( - input_dims[3] % (upscale_factor * upscale_factor), 0, - platform::errors::InvalidArgument( - "The square of upscale_factor[%u] should divide the " - "number of channel[%u]", - upscale_factor * upscale_factor, input_dims[3])); - } - auto output_dims = input_dims; - output_dims[0] = input_dims[0]; - if (!channel_last) { - output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor); - output_dims[2] = input_dims[2] * upscale_factor; - output_dims[3] = input_dims[3] * upscale_factor; - } else { - output_dims[1] = input_dims[1] * upscale_factor; - output_dims[2] = input_dims[2] * upscale_factor; - output_dims[3] = input_dims[3] / (upscale_factor * upscale_factor); - } - ctx->SetOutputDim("Out", output_dims); - } }; class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker { @@ -171,9 +124,13 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(pixel_shuffle, PixelShuffleInferShapeFunctor, + PT_INFER_META(phi::PixelShuffleInferMeta)); + REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker, ops::PixelShuffleGradMaker, - ops::PixelShuffleGradMaker); + ops::PixelShuffleGradMaker, + PixelShuffleInferShapeFunctor); REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp); diff --git a/paddle/fluid/operators/size_op.cc b/paddle/fluid/operators/size_op.cc index c27a02e46f332dc4e6374a201170b0178f87f463..e584c1a4cce1e85344c574526098b034723c3059 100644 --- a/paddle/fluid/operators/size_op.cc +++ b/paddle/fluid/operators/size_op.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -20,13 +23,6 @@ namespace operators { class SizeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Size"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Size"); - - ctx->SetOutputDim("Out", {1}); - } }; class SizeOpMaker : public framework::OpProtoAndCheckerMaker { @@ -48,7 +44,10 @@ Return the number of elements in the input. } // namespace paddle namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(size, SizeInferShapeFunctor, + PT_INFER_META(phi::SizeInferMeta)); REGISTER_OPERATOR( size, ops::SizeOp, ops::SizeOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker, + SizeInferShapeFunctor); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 1a9dbf90dd45f2c4669a918851590191367c79cd..49fd0a343a470f2545fc563366256f4f92294297 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -856,6 +856,57 @@ void DiagInferMeta(const MetaTensor& x, } } +void SizeInferMeta(const MetaTensor& input, MetaTensor* out) { + out->set_dtype(DataType::INT64); + out->set_dims({1}); +} + +void PixelShuffleInferMeta(const MetaTensor& x, + int upscale_factor, + const std::string& data_format, + MetaTensor* out) { + auto input_dims = x.dims(); + PADDLE_ENFORCE_EQ(input_dims.size(), + 4, + phi::errors::InvalidArgument( + "Input should be a 4-D tensor of format [N, C, H, W] " + "or [N, H, W, C], but got %u.", + input_dims.size())); + + const bool channel_last = (data_format == "NHWC"); + + if (!channel_last) { + PADDLE_ENFORCE_EQ(input_dims[1] % (upscale_factor * upscale_factor), + 0, + phi::errors::InvalidArgument( + "The square of upscale_factor[%u] should divide the " + "number of channel[%u]", + upscale_factor * upscale_factor, + input_dims[1])); + } else { + PADDLE_ENFORCE_EQ(input_dims[3] % (upscale_factor * upscale_factor), + 0, + phi::errors::InvalidArgument( + "The square of upscale_factor[%u] should divide the " + "number of channel[%u]", + upscale_factor * upscale_factor, + input_dims[3])); + } + auto output_dims = input_dims; + output_dims[0] = input_dims[0]; + if (!channel_last) { + output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor); + output_dims[2] = input_dims[2] * upscale_factor; + output_dims[3] = input_dims[3] * upscale_factor; + } else { + output_dims[1] = input_dims[1] * upscale_factor; + output_dims[2] = input_dims[2] * upscale_factor; + output_dims[3] = input_dims[3] / (upscale_factor * upscale_factor); + } + out->set_dtype(x.dtype()); + out->set_dims(output_dims); +} + } // namespace phi PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 172ea2a565e18503399c8f4475b4b519a23ce4fd..4fab1ec68ec1e71af5e55a9852cd68deccc09a7c 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -129,4 +129,11 @@ void DiagInferMeta(const MetaTensor& x, float padding_value, MetaTensor* out); +void SizeInferMeta(const MetaTensor& input, MetaTensor* out); + +void PixelShuffleInferMeta(const MetaTensor& x, + int upscale_factor, + const std::string& data_format, + MetaTensor* out); + } // namespace phi