From 1c2058834367464b4a293dbb58b6fa2137c24cc5 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Thu, 3 Mar 2022 20:23:51 +0800 Subject: [PATCH] move eye, lerp infershape to phi (#40105) --- paddle/fluid/operators/eye_op.cc | 26 ++++-------- paddle/fluid/operators/lerp_op.cc | 50 +++-------------------- paddle/phi/infermeta/nullary.cc | 8 ++++ paddle/phi/infermeta/nullary.h | 5 +++ paddle/phi/infermeta/ternary.cc | 17 ++++++++ paddle/phi/infermeta/ternary.h | 5 +++ paddle/phi/kernels/eye_kernel.h | 2 +- paddle/phi/kernels/funcs/common_shape.h | 25 ++++++++++++ paddle/phi/kernels/impl/eye_kernel_impl.h | 2 +- 9 files changed, 75 insertions(+), 65 deletions(-) diff --git a/paddle/fluid/operators/eye_op.cc b/paddle/fluid/operators/eye_op.cc index 8f8a0f174a7..f8c6b4eb8c5 100644 --- a/paddle/fluid/operators/eye_op.cc +++ b/paddle/fluid/operators/eye_op.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/nullary.h" namespace paddle { namespace operators { @@ -21,24 +24,6 @@ class EyeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::InvalidArgument( - "Output(Out) of EyeOP should not be null.")); - auto num_rows = ctx->Attrs().Get("num_rows"); - PADDLE_ENFORCE_EQ( - num_rows >= 0, true, - platform::errors::InvalidArgument( - "The value of Input(num_rows) should be non-negative int.")); - auto num_columns = ctx->Attrs().Get("num_columns"); - if (num_columns == -1) num_columns = num_rows; - PADDLE_ENFORCE_EQ( - num_columns >= 0, true, - platform::errors::InvalidArgument( - "The value of Input(num_columns) should be non-negative int.")); - ctx->SetOutputDim("Out", {num_rows, num_columns}); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -82,8 +67,11 @@ Return an identity tensor whose shape is [num_rows, num_columns]. } // namespace paddle namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(eye, EyeInferShapeFunctor, + PT_INFER_META(phi::EyeInferMeta)); REGISTER_OPERATOR( eye, ops::EyeOp, ops::EyeOpMaker, ops::EyeOpVarTypeInference, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker, + EyeInferShapeFunctor); diff --git a/paddle/fluid/operators/lerp_op.cc b/paddle/fluid/operators/lerp_op.cc index 0aaefc7ca75..fef6fc5319e 100644 --- a/paddle/fluid/operators/lerp_op.cc +++ b/paddle/fluid/operators/lerp_op.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/ternary.h" namespace paddle { namespace operators { @@ -20,49 +23,6 @@ namespace operators { class LerpOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lerp"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "lerp"); - OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "lerp"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "lerp"); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - auto w_dims = ctx->GetInputDim("Weight"); - framework::DDim out_dims; - out_dims = GetOutputDims(x_dims, y_dims); - if (w_dims.size() > 1 || w_dims[0] != 1) { - out_dims = GetOutputDims(out_dims, w_dims); - } - - ctx->SetOutputDim("Out", out_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } - - private: - framework::DDim GetOutputDims(const framework::DDim& s_dims, - const framework::DDim& l_dims) const { - if (s_dims.size() > l_dims.size()) { - return GetOutputDims(l_dims, s_dims); - } - std::vector shapes = phi::vectorize(l_dims); - for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) { - int64_t s = s_dims[i]; - int64_t l = l_dims[j]; - if (s != l) { - if (l == 1) { - shapes[j] = s; - } else if (s != 1) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The shape of tensor a %s:%d must match shape of tensor b " - "%s:%d.", - s_dims.to_str(), i, l_dims.to_str(), j)); - } - } - } - return phi::make_ddim(shapes); - } }; class LerpOpMaker : public framework::OpProtoAndCheckerMaker { @@ -125,10 +85,12 @@ DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"}); } // namespace operators } // namespace paddle +DELCARE_INFER_SHAPE_FUNCTOR(lerp, LerpInferShapeFunctor, + PT_INFER_META(phi::LerpInferMeta)); REGISTER_OPERATOR( lerp, paddle::operators::LerpOp, paddle::operators::LerpOpMaker, paddle::operators::LerpOpGradMaker, paddle::operators::LerpOpGradMaker, - paddle::operators::LerpInplaceInferer); + paddle::operators::LerpInplaceInferer, LerpInferShapeFunctor); REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp); diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 1fdf8a6940a..0c48c9d0c7e 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -32,4 +32,12 @@ void CreateInferMeta(const ScalarArray& shape, CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out); } +void EyeInferMeta(int64_t num_rows, + int64_t num_columns, + DataType dtype, + MetaTensor* out) { + if (num_columns == -1) num_columns = num_rows; + out->set_dims({num_rows, num_columns}); + out->set_dtype(dtype); +} } // namespace phi diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index ea5bb71551b..40d6ea595c0 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -35,4 +35,9 @@ void CreateInferMetaBase(const std::vector& shape, void CreateInferMeta(const ScalarArray& shape, DataType dtype, MetaTensor* out); +void EyeInferMeta(int64_t num_rows, + int64_t num_columns, + DataType dtype, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 52aeaef8438..1c1497fb0e4 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -89,4 +89,21 @@ void AddmmInferMeta(const MetaTensor& input, out->set_dtype(input.dtype()); } +void LerpInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& weight, + MetaTensor* out) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + auto w_dims = weight.dims(); + DDim out_dims; + out_dims = funcs::GetOutputDims(x_dims, y_dims); + if (w_dims.size() > 1 || w_dims[0] != 1) { + out_dims = funcs::GetOutputDims(out_dims, w_dims); + } + out->set_dims(out_dims); + out->set_dtype(x.dtype()); + out->share_lod(x); +} + } // namespace phi diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index d6223dd87aa..5679c5b533f 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -37,4 +37,9 @@ void AddmmInferMeta(const MetaTensor& input, float beta, MetaTensor* out); +void LerpInferMeta(const MetaTensor& x, + const MetaTensor& y, + const MetaTensor& weight, + MetaTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/eye_kernel.h b/paddle/phi/kernels/eye_kernel.h index 8b21b8ae405..e9e1abffd14 100644 --- a/paddle/phi/kernels/eye_kernel.h +++ b/paddle/phi/kernels/eye_kernel.h @@ -22,7 +22,7 @@ template void EyeKernel(const Context& ctx, int64_t num_rows, int64_t num_columns, - int dtype, + DataType dtype, DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h index d5289dcc22c..dce80caab72 100644 --- a/paddle/phi/kernels/funcs/common_shape.h +++ b/paddle/phi/kernels/funcs/common_shape.h @@ -140,5 +140,30 @@ inline bool CheckDims(const DDim &dims_x, const DDim &dims_y) { return true; } +inline DDim GetOutputDims(const DDim &s_dims, const DDim &l_dims) { + if (s_dims.size() > l_dims.size()) { + return GetOutputDims(l_dims, s_dims); + } + std::vector shapes = phi::vectorize(l_dims); + for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) { + int64_t s = s_dims[i]; + int64_t l = l_dims[j]; + if (s != l) { + if (l == 1) { + shapes[j] = s; + } else if (s != 1) { + PADDLE_THROW(errors::InvalidArgument( + "The shape of tensor a %s:%d must match shape of tensor b " + "%s:%d.", + s_dims.to_str(), + i, + l_dims.to_str(), + j)); + } + } + } + return phi::make_ddim(shapes); +} + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/impl/eye_kernel_impl.h b/paddle/phi/kernels/impl/eye_kernel_impl.h index 453652273a2..f4041f921fd 100644 --- a/paddle/phi/kernels/impl/eye_kernel_impl.h +++ b/paddle/phi/kernels/impl/eye_kernel_impl.h @@ -36,7 +36,7 @@ template void EyeKernel(const Context& ctx, int64_t num_rows, int64_t num_columns, - int dtype, + DataType dtype, DenseTensor* out) { auto num = num_columns; if (num == -1) { -- GitLab