diff --git a/paddle/fluid/operators/matrix_power_op.cc b/paddle/fluid/operators/matrix_power_op.cc
index cdf204628b638f877c92e35a8941487aa39b5427..56f65340ea999f48702294f912c4354d83990881 100644
--- a/paddle/fluid/operators/matrix_power_op.cc
+++ b/paddle/fluid/operators/matrix_power_op.cc
@@ -14,8 +14,11 @@
 #include <memory>
 #include <string>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -23,26 +26,6 @@ namespace operators {
 class MatrixPowerOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "matrix_power");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "matrix_power");
-    auto dims = ctx->GetInputDim("X");
-    auto n_dim = dims.size();
-    PADDLE_ENFORCE_GE(n_dim, 2,
-                      platform::errors::InvalidArgument(
-                          "The Input(X) should have at least 2 dimensions. But "
-                          "received a %d dimension tensor.",
-                          n_dim));
-    PADDLE_ENFORCE_EQ(dims[n_dim - 2], dims[n_dim - 1],
-                      platform::errors::InvalidArgument(
-                          "The inner-most 2 dimensions of Input(X) all should "
-                          "be square matrices "
-                          "But received X's shape[-2] = %d and shape[-1] = %d.",
-                          dims[n_dim - 2], dims[n_dim - 1]));
-    ctx->SetOutputDim("Out", dims);
-    ctx->ShareLoD("X", /*->*/ "Out");
-  }
 };
 
 class MatrixPowerOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -116,9 +99,14 @@ class MatrixPowerGradOpMaker : public framework::SingleGradOpMaker {
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(matrix_power, MatrixPowerInferShapeFunctor,
+                            PD_INFER_META(phi::MatrixPowerInferMeta));
+
 REGISTER_OPERATOR(matrix_power, ops::MatrixPowerOp, ops::MatrixPowerOpMaker,
                   ops::MatrixPowerOpInferVarType,
                   ops::MatrixPowerGradOpMaker<paddle::framework::OpDesc>,
-                  ops::MatrixPowerGradOpMaker<paddle::imperative::OpBase>);
+                  ops::MatrixPowerGradOpMaker<paddle::imperative::OpBase>,
+                  MatrixPowerInferShapeFunctor);
 
 REGISTER_OPERATOR(matrix_power_grad, ops::MatrixPowerGradOp);
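[Note, not part of the patch] With InferShape moved into phi, the matrix_power
shape checks can be exercised directly against phi::MetaTensor, without
building an operator. A minimal sketch (function name hypothetical; assumes
phi::MetaTensor's TensorBase* constructor and a default-constructed
DenseTensor behave as in the phi infermeta unit tests):

    #include "paddle/phi/core/ddim.h"
    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/core/meta_tensor.h"
    #include "paddle/phi/infermeta/unary.h"

    // Run matrix_power's shape inference on a batch of 3x3 matrices.
    void MatrixPowerInferMetaSketch() {
      phi::DenseTensor x, out;
      x.Resize(phi::make_ddim({4, 3, 3}));  // inner-most dims are square
      phi::MetaTensor meta_x(&x), meta_out(&out);
      phi::MatrixPowerInferMeta(meta_x, /*n=*/2, &meta_out);
      // meta_out now carries dims {4, 3, 3} and x's dtype; a non-square
      // input such as {4, 3, 2} would trip the PADDLE_ENFORCE_EQ above.
    }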
diff --git a/paddle/fluid/operators/multi_dot_op.cc b/paddle/fluid/operators/multi_dot_op.cc
index b309e1b87ef9033bd4302cdad4ea60a64cbf02eb..5b107ce643df33af79230c30d784d1ad84c26666 100644
--- a/paddle/fluid/operators/multi_dot_op.cc
+++ b/paddle/fluid/operators/multi_dot_op.cc
@@ -16,77 +16,19 @@ limitations under the License.
 */
 #include <string>
 #include <vector>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/strided_memcpy.h"
 #include "paddle/fluid/operators/utils.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 
 namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
 
-/**
- * @brief compute the output shape and check the input shape valid or not
- */
-inline framework::DDim ComputeAndCheckShape(
-    const bool is_runtime, const std::vector<framework::DDim>& inputs_dims) {
-  const size_t n = inputs_dims.size();
-  auto first_dim = inputs_dims[0];
-
-  bool is_vector = false;
-  framework::DDim out_dim;
-
-  PADDLE_ENFORCE_LT(
-      first_dim.size(), static_cast<size_t>(3),
-      platform::errors::InvalidArgument(
-          "multi_dot: the first input tensor must be 1D or 2D but got[%d]!",
-          static_cast<int>(first_dim.size())));
-
-  // If the first tensor is 1D of size n view it as a row vector (1, n)
-  if (first_dim.size() == 1) {
-    first_dim = phi::make_ddim({1, static_cast<int>(first_dim[0])});
-    is_vector = true;
-  }
-
-  auto last_dim = inputs_dims[n - 1];
-  PADDLE_ENFORCE_LT(
-      last_dim.size(), static_cast<size_t>(3),
-      platform::errors::InvalidArgument(
-          "the last input tensor of multi_dot must be 1D or 2D but got[%d]!",
-          static_cast<int>(first_dim.size())));
-
-  // If the last tensor is 1D of size n view it as a column vector (n, 1)
-  if (last_dim.size() == 1) {
-    last_dim = phi::make_ddim({static_cast<int>(last_dim[0]), 1});
-    out_dim = is_vector ? phi::make_ddim({1}) : phi::make_ddim({first_dim[0]});
-  } else {
-    out_dim = is_vector ? phi::make_ddim({last_dim[1]})
-                        : phi::make_ddim({first_dim[0], last_dim[1]});
-  }
-
-  auto width = first_dim[1];
-  for (size_t i = 1; i < n - 1; i++) {
-    PADDLE_ENFORCE_EQ(inputs_dims[i].size(), static_cast<size_t>(2),
-                      platform::errors::InvalidArgument(
-                          "the input tensor of multi_dot op must be 2D."));
-
-    const auto& tmp_dim = inputs_dims[i];
-    PADDLE_ENFORCE_EQ(
-        tmp_dim[0], width,
-        platform::errors::InvalidArgument(
-            "the input matrix does not meet the multiplication requirements."));
-    width = tmp_dim[1];
-  }
-
-  PADDLE_ENFORCE_EQ(
-      last_dim[0], width,
-      platform::errors::InvalidArgument(
-          "the input matrix does not meet the multiplication requirements."));
-
-  return out_dim;
-}
-
 class MultiDotOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -105,22 +47,6 @@ If the first argument is 1-D it is treated as a row vector.
 If the last argument is 1-D it is treated as a column vector.
 
 class MultiDotOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "multi_dot");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "multi_dot");
-
-    auto inputs_dims = ctx->GetInputsDim("X");
-
-    const size_t inputs_num = inputs_dims.size();
-    PADDLE_ENFORCE_GT(
-        inputs_num, static_cast<size_t>(1),
-        platform::errors::InvalidArgument(
-            "The number of input tensors in multi_dot op should > 1."));
-    auto out_dims = ComputeAndCheckShape(ctx->IsRuntime(), inputs_dims);
-    ctx->SetOutputDim("Out", out_dims);
-    ctx->ShareLoD("X", "Out");
-  }
 };
 
 class MultiDotOpGrad : public framework::OperatorWithKernel {
@@ -171,9 +97,15 @@ class MultiDotOpDoubleGradMaker : public framework::SingleGradOpMaker {
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(multi_dot, MultiDotInferShapeFunctor,
+                            PD_INFER_META(phi::MultiDotInferMeta));
+
 REGISTER_OPERATOR(multi_dot, ops::MultiDotOp, ops::MultiDotOpMaker,
                   ops::MultiDotOpGradMaker<paddle::framework::OpDesc>,
-                  ops::MultiDotOpGradMaker<paddle::imperative::OpBase>);
+                  ops::MultiDotOpGradMaker<paddle::imperative::OpBase>,
+                  MultiDotInferShapeFunctor);
+
 REGISTER_OPERATOR(multi_dot_grad, ops::MultiDotOpGrad,
                   ops::MultiDotOpDoubleGradMaker<paddle::framework::OpDesc>,
                   ops::MultiDotOpDoubleGradMaker<paddle::imperative::OpBase>);
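[Note, not part of the patch] The shape rules that moved into
phi::MultiDotInferMeta, on concrete inputs: (2,3) x (3,4) x (4,5) chains to
(2,5); a 1-D first argument is viewed as a row vector (1, n), a 1-D last
argument as a column vector (n, 1). A hedged sketch (function name
hypothetical, same MetaTensor assumptions as above):

    #include <vector>
    #include "paddle/phi/core/ddim.h"
    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/core/meta_tensor.h"
    #include "paddle/phi/infermeta/multiary.h"

    // (2,3) x (3,4) x (4,5) -> out dims (2,5).
    void MultiDotInferMetaSketch() {
      phi::DenseTensor a, b, c, out;
      a.Resize(phi::make_ddim({2, 3}));
      b.Resize(phi::make_ddim({3, 4}));
      c.Resize(phi::make_ddim({4, 5}));
      phi::MetaTensor ma(&a), mb(&b), mc(&c), mout(&out);
      std::vector<phi::MetaTensor*> xs = {&ma, &mb, &mc};
      phi::MultiDotInferMeta(xs, &mout);
      // mout now carries dims {2, 5}; mismatched widths, e.g. b of shape
      // {5, 4}, would fail the multiplication-requirements check.
    }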
diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc
index e2c8359beb1290f7b1b592c1ff24b15986f41f73..9001ce5d51dece5c6cee481f3f6f92e69c302c2b 100644
--- a/paddle/fluid/operators/shape_op.cc
+++ b/paddle/fluid/operators/shape_op.cc
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <string>
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -22,17 +25,6 @@ class ShapeOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
-                      platform::errors::InvalidArgument(
-                          "Input (Input) of get_shape op should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
-                      platform::errors::InvalidArgument(
-                          "Output (Out) of get_shape op should not be null."));
-    auto in_dim = ctx->GetInputDim("Input");
-    ctx->SetOutputDim("Out", {in_dim.size()});
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
@@ -89,7 +81,12 @@ Return the shape of the input.
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
+
+DECLARE_INFER_SHAPE_FUNCTOR(shape, ShapeInferShapeFunctor,
+                            PD_INFER_META(phi::ShapeInferMeta));
+
 REGISTER_OPERATOR(
     shape, ops::ShapeOp, ops::ShapeOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ShapeInferShapeFunctor);
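[Note, not part of the patch] phi::ShapeInferMeta only fixes the metadata of
the result: a 1-D INT32 tensor whose length is the input's rank; the actual
dimension values are written later by the kernel. A minimal sketch (function
name hypothetical, same assumptions as the sketches above):

    #include "paddle/phi/core/ddim.h"
    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/core/meta_tensor.h"
    #include "paddle/phi/infermeta/unary.h"

    void ShapeInferMetaSketch() {
      phi::DenseTensor x, out;
      x.Resize(phi::make_ddim({6, 9, 5}));
      phi::MetaTensor meta_x(&x), meta_out(&out);
      phi::ShapeInferMeta(meta_x, &meta_out);
      // meta_out: dims {3}, dtype INT32. The values {6, 9, 5} are only
      // produced at run time by phi::ShapeKernel.
    }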
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 84441ed8b740be172ddaa7de3fc23ad420ebf077..ef75ab573c6d9bd5c65fad747a28f2c704257371 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -369,6 +369,79 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
   out->share_lod(*x.at(0));
 }
 
+void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
+  auto inputs_dims = GetMetaTensorsDim(x);
+
+  const size_t inputs_num = inputs_dims.size();
+  PADDLE_ENFORCE_GT(
+      inputs_num,
+      static_cast<size_t>(1),
+      phi::errors::InvalidArgument(
+          "The number of input tensors in multi_dot op should > 1."));
+
+  const size_t n = inputs_dims.size();
+  auto first_dim = inputs_dims[0];
+
+  bool is_vector = false;
+  phi::DDim out_dim;
+
+  PADDLE_ENFORCE_LT(
+      first_dim.size(),
+      static_cast<size_t>(3),
+      phi::errors::InvalidArgument(
+          "multi_dot: the first input tensor must be 1D or 2D but got[%d]!",
+          static_cast<int>(first_dim.size())));
+
+  // If the first tensor is 1D of size n view it as a row vector (1, n)
+  if (first_dim.size() == 1) {
+    first_dim = phi::make_ddim({1, static_cast<int>(first_dim[0])});
+    is_vector = true;
+  }
+
+  auto last_dim = inputs_dims[n - 1];
+  PADDLE_ENFORCE_LT(
+      last_dim.size(),
+      static_cast<size_t>(3),
+      phi::errors::InvalidArgument(
+          "the last input tensor of multi_dot must be 1D or 2D but got[%d]!",
+          static_cast<int>(first_dim.size())));
+
+  // If the last tensor is 1D of size n view it as a column vector (n, 1)
+  if (last_dim.size() == 1) {
+    last_dim = phi::make_ddim({static_cast<int>(last_dim[0]), 1});
+    out_dim = is_vector ? phi::make_ddim({1}) : phi::make_ddim({first_dim[0]});
+  } else {
+    out_dim = is_vector ? phi::make_ddim({last_dim[1]})
+                        : phi::make_ddim({first_dim[0], last_dim[1]});
+  }
+
+  auto width = first_dim[1];
+  for (size_t i = 1; i < n - 1; i++) {
+    PADDLE_ENFORCE_EQ(inputs_dims[i].size(),
+                      static_cast<size_t>(2),
+                      phi::errors::InvalidArgument(
+                          "the input tensor of multi_dot op must be 2D."));
+
+    const auto& tmp_dim = inputs_dims[i];
+    PADDLE_ENFORCE_EQ(
+        tmp_dim[0],
+        width,
+        phi::errors::InvalidArgument(
+            "the input matrix does not meet the multiplication requirements."));
+    width = tmp_dim[1];
+  }
+
+  PADDLE_ENFORCE_EQ(
+      last_dim[0],
+      width,
+      phi::errors::InvalidArgument(
+          "the input matrix does not meet the multiplication requirements."));
+
+  out->set_dims(out_dim);
+  out->set_dtype(x.at(0)->dtype());
+  out->share_lod(*x.at(0));
+}
+
 void PsroiPoolInferMeta(const MetaTensor& x,
                         const MetaTensor& rois,
                         paddle::optional<const MetaTensor&> rois_num,
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index c11843212ed33fd8170e6677a4d6e0ad95b730dc..0bdd35d5f58e8e9d5c3dd7956897bac0adbdf550 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -70,6 +70,8 @@ void ConcatInferMeta(const std::vector<MetaTensor*>& x,
                      MetaTensor* out,
                      MetaConfig config = MetaConfig());
 
+void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);
+
 void PsroiPoolInferMeta(const MetaTensor& x,
                         const MetaTensor& rois,
                         paddle::optional<const MetaTensor&> rois_num,
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 4d1cb42bd59f072e5926b237528a742c231bcdcf..752abae1b0333f46a749dca586936b0fca095720 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -554,6 +554,28 @@ void IsfiniteInferMeta(const MetaTensor& x, MetaTensor* out) {
   out->set_dtype(DataType::BOOL);
 }
 
+void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
+  auto dims = x.dims();
+  auto n_dim = dims.size();
+  PADDLE_ENFORCE_GE(n_dim,
+                    2,
+                    phi::errors::InvalidArgument(
+                        "The Input(X) should have at least 2 dimensions. But "
+                        "received a %d dimension tensor.",
+                        n_dim));
+  PADDLE_ENFORCE_EQ(dims[n_dim - 2],
+                    dims[n_dim - 1],
+                    phi::errors::InvalidArgument(
+                        "The inner-most 2 dimensions of Input(X) all should "
+                        "be square matrices "
+                        "But received X's shape[-2] = %d and shape[-1] = %d.",
+                        dims[n_dim - 2],
+                        dims[n_dim - 1]));
+  out->set_dims(dims);
+  out->share_lod(x);
+  out->set_dtype(x.dtype());
+}
+
 void MaxPoolWithIndexInferMeta(const MetaTensor& x,
                                const std::vector<int>& kernel_size,
                                const std::vector<int>& strides,
@@ -994,6 +1016,12 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
   ReshapeInferMeta(x, shape, out, config);
 }
 
+void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) {
+  auto in_dim = input.dims();
+  out->set_dims(phi::make_ddim({in_dim.size()}));
+  out->set_dtype(DataType::INT32);
+}
+
 void ShardIndexInferMeta(const MetaTensor& in,
                          int index_num,
                          int nshards,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 75fb9fadf82dc87ac18814e0674e5012fea95ec4..a9aefd1f12d67e994f6cc92c4bbb849654bb00b9 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -98,6 +98,8 @@ void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out);
 
 void IsfiniteInferMeta(const MetaTensor& input, MetaTensor* out);
 
+void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);
+
 void MaxPoolWithIndexInferMeta(const MetaTensor& x,
                                const std::vector<int>& kernel_size,
                                const std::vector<int>& strides,
@@ -162,6 +164,8 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
                                 MetaTensor* out,
                                 MetaConfig config = MetaConfig());
 
+void ShapeInferMeta(const MetaTensor& input, MetaTensor* out);
+
 void ShardIndexInferMeta(const MetaTensor& in,
                          int index_num,
                          int nshards,
But " + "received a %d dimension tensor.", + n_dim)); + PADDLE_ENFORCE_EQ(dims[n_dim - 2], + dims[n_dim - 1], + phi::errors::InvalidArgument( + "The inner-most 2 dimensions of Input(X) all should " + "be square matrices " + "But received X's shape[-2] = %d and shape[-1] = %d.", + dims[n_dim - 2], + dims[n_dim - 1])); + out->set_dims(dims); + out->share_lod(x); + out->set_dtype(x.dtype()); +} + void MaxPoolWithIndexInferMeta(const MetaTensor& x, const std::vector& kernel_size, const std::vector& strides, @@ -994,6 +1016,12 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, ReshapeInferMeta(x, shape, out, config); } +void ShapeInferMeta(const MetaTensor& input, MetaTensor* out) { + auto in_dim = input.dims(); + out->set_dims(phi::make_ddim({in_dim.size()})); + out->set_dtype(DataType::INT32); +} + void ShardIndexInferMeta(const MetaTensor& in, int index_num, int nshards, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 75fb9fadf82dc87ac18814e0674e5012fea95ec4..a9aefd1f12d67e994f6cc92c4bbb849654bb00b9 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -98,6 +98,8 @@ void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out); void IsfiniteInferMeta(const MetaTensor& input, MetaTensor* out); +void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out); + void MaxPoolWithIndexInferMeta(const MetaTensor& x, const std::vector& kernel_size, const std::vector& strides, @@ -162,6 +164,8 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void ShapeInferMeta(const MetaTensor& input, MetaTensor* out); + void ShardIndexInferMeta(const MetaTensor& in, int index_num, int nshards, diff --git a/paddle/phi/kernels/cpu/shape_kernel.cc b/paddle/phi/kernels/cpu/shape_kernel.cc deleted file mode 100644 index 073dc63b2a4348d4091af8c285f9ddebd799acc5..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/cpu/shape_kernel.cc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/kernels/shape_kernel.h" -#include "paddle/phi/kernels/impl/shape_kernel_impl.h" - -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" - -PD_REGISTER_KERNEL(shape, - CPU, - ALL_LAYOUT, - phi::ShapeKernel, - bool, - int, - int8_t, - uint8_t, - int64_t, - float, - double, - phi::dtype::complex, - phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/shape_kernel_impl.h b/paddle/phi/kernels/impl/shape_kernel_impl.h deleted file mode 100644 index 982cfb33f6b14fc14c7c58ff8c4548a4cdbd3b3b..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/impl/shape_kernel_impl.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/phi/core/dense_tensor.h"
-
-namespace phi {
-
-template <typename T, typename Context>
-void ShapeKernel(const Context& ctx,
-                 const DenseTensor& input,
-                 DenseTensor* out) {
-  auto in_var = &input;
-  phi::DDim in_dims;
-  in_dims = in_var->dims();
-  auto out_t = out;
-  out_t->Resize({in_dims.size()});
-  auto out_data = ctx.template HostAlloc<int32_t>(out_t);
-  for (int i = 0; i < in_dims.size(); ++i) {
-    out_data[i] = in_dims[i];
-  }
-}
-
-}  // namespace phi
diff --git a/paddle/phi/kernels/selected_rows/shape_kernel.cc b/paddle/phi/kernels/selected_rows/shape_kernel.cc
index 9bcd5d8544e2d73961d72115023446d427e8895e..67126d82042b28de8c560a55046e50029153290d 100644
--- a/paddle/phi/kernels/selected_rows/shape_kernel.cc
+++ b/paddle/phi/kernels/selected_rows/shape_kernel.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/shape_kernel.h"
 
 namespace phi {
 namespace sr {
 
@@ -25,15 +26,7 @@ template <typename T, typename Context>
 void ShapeKernel(const Context& ctx,
                  const SelectedRows& input,
                  DenseTensor* out) {
-  auto in_var = input;
-  phi::DDim in_dims;
-  in_dims = in_var.value().dims();
-  auto out_t = out;
-  out_t->Resize({in_dims.size()});
-  auto out_data = ctx.template HostAlloc<int32_t>(out_t);
-  for (int i = 0; i < in_dims.size(); ++i) {
-    out_data[i] = in_dims[i];
-  }
+  phi::ShapeKernel<T, Context>(ctx, input.value(), out);
 }
 
 }  // namespace sr
diff --git a/paddle/phi/kernels/gpu/shape_kernel.cu b/paddle/phi/kernels/shape_kernel.cc
similarity index 53%
rename from paddle/phi/kernels/gpu/shape_kernel.cu
rename to paddle/phi/kernels/shape_kernel.cc
index 39b6eaeaef2a8e80d204941dc1f3ac92907aa786..dd26a7edc9cdd8e1917bb5d88e957b3e7d545f93 100644
--- a/paddle/phi/kernels/gpu/shape_kernel.cu
+++ b/paddle/phi/kernels/shape_kernel.cc
@@ -13,12 +13,43 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/phi/kernels/shape_kernel.h"
-
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/common/float16.h"
+#include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/shape_kernel_impl.h"
 
+namespace phi {
+
+template <typename T, typename Context>
+void ShapeKernel(const Context& ctx,
+                 const DenseTensor& input,
+                 DenseTensor* out) {
+  auto in_var = &input;
+  phi::DDim in_dims;
+  in_dims = in_var->dims();
+  auto out_t = out;
+  out_t->Resize({in_dims.size()});
+  auto out_data = ctx.template HostAlloc<int32_t>(out_t);
+  for (int i = 0; i < in_dims.size(); ++i) {
+    out_data[i] = in_dims[i];
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(shape,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::ShapeKernel,
+                   bool,
+                   int,
+                   int8_t,
+                   uint8_t,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PD_REGISTER_KERNEL(shape,
                    GPU,
                    ALL_LAYOUT,
@@ -33,3 +64,4 @@ PD_REGISTER_KERNEL(shape,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>,
                    phi::dtype::float16) {}
+#endif
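[Note, not part of the patch] After the rename, a single translation unit
defines phi::ShapeKernel, registers it for CPU unconditionally and for GPU
only under PADDLE_WITH_CUDA/PADDLE_WITH_HIP, and the SelectedRows variant
forwards to it. A hedged usage sketch (function name hypothetical; assumes a
ready phi::CPUContext and that HostAlloc is valid on a default-constructed
output tensor):

    #include "paddle/phi/backends/cpu/cpu_context.h"
    #include "paddle/phi/core/ddim.h"
    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/kernels/shape_kernel.h"

    // Writes the dims of x, {6, 9, 5}, into a 1-D int32 host tensor.
    void ShapeKernelSketch(const phi::CPUContext& ctx) {
      phi::DenseTensor x, out;
      x.Resize(phi::make_ddim({6, 9, 5}));
      phi::ShapeKernel<float, phi::CPUContext>(ctx, x, &out);
      // out.dims() == {3}; out's buffer holds 6, 9, 5. The SelectedRows
      // overload above reaches this same code via input.value().
    }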