diff --git a/paddle/fluid/operators/batch_size_like.h b/paddle/fluid/operators/batch_size_like.h index 530a9e8066d8c776fc458fdcf076e894934cd8d4..facb4cd82542b251695087ff2d129606199bb7a0 100644 --- a/paddle/fluid/operators/batch_size_like.h +++ b/paddle/fluid/operators/batch_size_like.h @@ -15,61 +15,32 @@ limitations under the License. */ #pragma once #include #include +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { namespace operators { +using MetaTensor = framework::CompatMetaTensor; + class BatchSizeLikeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", Type()); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", Type()); - auto &shape = ctx->Attrs().Get>("shape"); - PADDLE_ENFORCE_GT(shape.size(), 0, - platform::errors::InvalidArgument( - "Shape size must be larger than 0, but received: %s.", - shape.size())); - std::vector shape_int64(shape.size(), 0); - std::transform(shape.begin(), shape.end(), shape_int64.begin(), - [](int a) { return static_cast(a); }); - auto output_dim = phi::make_ddim(shape_int64); - - int input_dim_idx = ctx->Attrs().Get("input_dim_idx"); - int input_dim_size = static_cast(ctx->GetInputDim("Input").size()); - PADDLE_ENFORCE_GE(input_dim_idx, 0, - platform::errors::InvalidArgument( - "Input dimension index must be larger " - "equal than 0, but received: %s.", - input_dim_idx)); - PADDLE_ENFORCE_GT(input_dim_size, input_dim_idx, - platform::errors::InvalidArgument( - "Input dimension size must be larger than " - "input dimension index, but received input " - "dimension size: %s, input dimension index: %s.", - input_dim_size, input_dim_idx)); - - int output_dim_idx = ctx->Attrs().Get("output_dim_idx"); - int output_dim_size = static_cast(shape.size()); - PADDLE_ENFORCE_GE(output_dim_idx, 0, - platform::errors::InvalidArgument( - "Output dimension index must be larger " - "equal than 0, but received: %s.", - output_dim_idx)); - PADDLE_ENFORCE_GT( - output_dim_size, output_dim_idx, - platform::errors::InvalidArgument( - "Output dimension size must be larger than output dimension index, " - "but received output dimension size: %s, output dimension index: " - "%s.", - output_dim_size, output_dim_idx)); - - output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx]; - ctx->SetOutputDim("Out", output_dim); + MetaTensor x(ctx->GetInputVarPtrs("Input")[0], ctx->IsRuntime()); + MetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime()); + auto& shape = ctx->Attrs().Get>("shape"); + int x_batch_size_dim = ctx->Attrs().Get("input_dim_idx"); + int out_batch_size_dim = ctx->Attrs().Get("output_dim_idx"); + phi::BatchSizeLikeInferMeta(x, shape, x_batch_size_dim, out_batch_size_dim, + &out); } }; diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op.cc b/paddle/fluid/operators/fill_constant_batch_size_like_op.cc index 57e7cbb74079ed44a3f5554cda00243dc51f3a31..950a50e52439cd5a709cf07942507507c926d240 100644 --- a/paddle/fluid/operators/fill_constant_batch_size_like_op.cc +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/operators/batch_size_like.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -61,9 +64,13 @@ obtained from the `input` tensor. } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(fill_constant_batch_size_like, + FillConstantBatchSizeLikeInferShapeFunctor, + PD_INFER_META(phi::FullBatchSizeLikeInferMeta)); REGISTER_OPERATOR( fill_constant_batch_size_like, ops::FillConstantBatchSizeLikeOp, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, ops::FillConstantBatchSizeLikeOpMaker, - ops::BatchSizeLikeNoNeedBufferVarsInferer); + ops::BatchSizeLikeNoNeedBufferVarsInferer, + FillConstantBatchSizeLikeInferShapeFunctor); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index a67cc270c25e5f1c69baef00553855893f637cc5..7fc807f28fbf2c4d486f9b4246652a86b9650922 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -133,6 +133,59 @@ void ArgsortInferMeta(const MetaTensor& input, indices->share_lod(input); } +void BatchSizeLikeInferMeta(const MetaTensor& x, + const std::vector& shape, + int x_batch_size_dim, + int out_batch_size_dim, + MetaTensor* out) { + PADDLE_ENFORCE_GT( + shape.size(), + 0UL, + phi::errors::InvalidArgument( + "Shape size must be larger than 0, but received: %s.", shape.size())); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), [](int a) { + return static_cast(a); + }); + auto output_dim = phi::make_ddim(shape_int64); + + int input_dim_size = static_cast(x.dims().size()); + PADDLE_ENFORCE_GE( + x_batch_size_dim, + 0, + phi::errors::InvalidArgument("Input dimension index must be larger " + "equal than 0, but received: %s.", + x_batch_size_dim)); + PADDLE_ENFORCE_GT(input_dim_size, + x_batch_size_dim, + phi::errors::InvalidArgument( + "Input dimension size must be larger than " + "input dimension index, but received input " + "dimension size: %s, input dimension index: %s.", + input_dim_size, + x_batch_size_dim)); + + int output_dim_size = static_cast(shape.size()); + PADDLE_ENFORCE_GE( + out_batch_size_dim, + 0, + phi::errors::InvalidArgument("Output dimension index must be larger " + "equal than 0, but received: %s.", + out_batch_size_dim)); + PADDLE_ENFORCE_GT( + output_dim_size, + out_batch_size_dim, + phi::errors::InvalidArgument( + "Output dimension size must be larger than output dimension index, " + "but received output dimension size: %s, output dimension index: " + "%s.", + output_dim_size, + out_batch_size_dim)); + + output_dim[out_batch_size_dim] = x.dims()[x_batch_size_dim]; + out->set_dims(output_dim); +} + void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(out_dtype); @@ -413,6 +466,17 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x, xshape->share_lod(x); } +void FullBatchSizeLikeInferMeta(const MetaTensor& x, + const std::vector& shape, + const Scalar& val, + DataType dtype, + int x_batch_size_dim, + int out_batch_size_dim, + MetaTensor* out) { + BatchSizeLikeInferMeta(x, shape, x_batch_size_dim, out_batch_size_dim, out); + out->set_dtype(dtype); +} + void GumbelSoftmaxInferMeta(const MetaTensor& x, float temperature, bool hard, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 697926b76aea16c86f5aac4783fafbeba75f08d9..fe11f7d44ab401e38d81a0637b5bc4ba1a6958bc 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -48,6 +48,12 @@ void ArgsortInferMeta(const MetaTensor& input, MetaTensor* output, MetaTensor* indices); +void BatchSizeLikeInferMeta(const MetaTensor& x, + const std::vector& shape, + int x_batch_size_dim, + int out_batch_size_dim, + MetaTensor* out); + void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out); @@ -92,6 +98,14 @@ void FlattenWithXShapeInferMeta(const MetaTensor& x, MetaTensor* out, MetaTensor* xshape); +void FullBatchSizeLikeInferMeta(const MetaTensor& x, + const std::vector& shape, + const Scalar& val, + DataType dtype, + int x_batch_size_dim, + int out_batch_size_dim, + MetaTensor* out); + void GumbelSoftmaxInferMeta(const MetaTensor& x, float temperature, bool hard,