From d35b5b585278025d8f732ca8244157cea70df5f3 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 11 Mar 2022 10:10:52 +0800 Subject: [PATCH] [phi] [infershape] transfer nll_loss infer shape into phi (#40375) * transfer nll_loss infershape into phi --- paddle/fluid/operators/nll_loss_op.cc | 78 ++------------------------ paddle/phi/infermeta/ternary.cc | 80 +++++++++++++++++++++++++++ paddle/phi/infermeta/ternary.h | 9 +++ 3 files changed, 95 insertions(+), 72 deletions(-) diff --git a/paddle/fluid/operators/nll_loss_op.cc b/paddle/fluid/operators/nll_loss_op.cc index 6c35ad29e9..a4e1f7b309 100644 --- a/paddle/fluid/operators/nll_loss_op.cc +++ b/paddle/fluid/operators/nll_loss_op.cc @@ -14,7 +14,9 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/ternary.h" namespace paddle { namespace operators { @@ -23,77 +25,6 @@ class NLLLossOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "NLLLoss"); - OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "NLLLoss"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NLLLoss"); - OP_INOUT_CHECK(ctx->HasOutput("Total_weight"), "Output", "Total_weight", - "NLLLoss"); - - auto x_dims = ctx->GetInputDim("X"); - auto label_dims = ctx->GetInputDim("Label"); - auto reduction = ctx->Attrs().Get("reduction"); - - PADDLE_ENFORCE_EQ(x_dims.size() == 2 || x_dims.size() == 4, true, - platform::errors::InvalidArgument( - "The tensor rank of Input(X) must be 2 or 4.")); - bool contain_unknown_dim = phi::contain_unknown_dim(x_dims) || - phi::contain_unknown_dim(label_dims); - bool check = ctx->IsRuntime() || !contain_unknown_dim; - if (check) { - PADDLE_ENFORCE_EQ( - x_dims[0], label_dims[0], - platform::errors::InvalidArgument( - "ShapeError: Expected input batch_size to match label batch_size," - "But received: the Input(x) batch_size is [%s], the Input(label) " - " batch_size is [%s].", - x_dims[0], label_dims[0])); - if (ctx->HasInput("Weight")) { - auto w_dims = ctx->GetInputDim("Weight"); - PADDLE_ENFORCE_EQ(w_dims.size(), 1, - platform::errors::InvalidArgument( - "Input(Weight) should be a 1D tensor.")); - PADDLE_ENFORCE_EQ( - x_dims[1], w_dims[0], - platform::errors::InvalidArgument( - "Expected input tensor Weight's size should equal " - "to the first dimension of the input tensor X. But received " - "Weight's " - "size is %d, the first dimension of input X is %d", - w_dims[0], x_dims[1])); - } - } - if (x_dims.size() == 2) { - if (reduction == "none") { - ctx->SetOutputDim("Out", {x_dims[0]}); - } else { - ctx->SetOutputDim("Out", {1}); - } - } else if (x_dims.size() == 4) { - PADDLE_ENFORCE_EQ(label_dims.size(), 3, - platform::errors::InvalidArgument( - "Expected Input(Lable) dimensions=3, received %d.", - label_dims.size())); - auto input0 = x_dims[0]; - auto input2 = x_dims[2]; - auto input3 = x_dims[3]; - auto label0 = label_dims[0]; - auto label1 = label_dims[1]; - auto label2 = label_dims[2]; - PADDLE_ENFORCE_EQ( - input0 == label0 && input2 == label1 && input3 == label2, true, - platform::errors::InvalidArgument("Input(X) tensor shape should " - "match to Input(Label) tensor " - "shape.")); - if (reduction == "none") { - ctx->SetOutputDim("Out", {x_dims[0], x_dims[2], x_dims[3]}); - } else { - ctx->SetOutputDim("Out", {1}); - } - } - ctx->SetOutputDim("Total_weight", {1}); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -259,8 +190,11 @@ class NLLLossGradMaker : public framework::SingleGradOpMaker { } // namespace operators } // namespace paddle +DECLARE_INFER_SHAPE_FUNCTOR(nll_loss, NllLossRawInferShapeFunctor, + PD_INFER_META(phi::NllLossRawInferMeta)); namespace ops = paddle::operators; REGISTER_OPERATOR(nll_loss, ops::NLLLossOp, ops::NLLLossOpMaker, ops::NLLLossGradMaker, - ops::NLLLossGradMaker); + ops::NLLLossGradMaker, + NllLossRawInferShapeFunctor); REGISTER_OPERATOR(nll_loss_grad, ops::NLLLossGradOp); diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 813c7243b3..88ac2cb0f8 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -89,6 +89,86 @@ void AddmmInferMeta(const MetaTensor& input, out->set_dtype(input.dtype()); } +void NllLossRawInferMeta(const MetaTensor& input, + const MetaTensor& label, + paddle::optional weight, + int64_t ignore_index, + const std::string& reduction, + MetaTensor* out, + MetaTensor* total_weight, + MetaConfig config) { + auto x_dims = input.dims(); + auto label_dims = label.dims(); + PADDLE_ENFORCE_EQ(x_dims.size() == 2 || x_dims.size() == 4, + true, + phi::errors::InvalidArgument( + "The tensor rank of Input(X) must be 2 or 4.")); + bool contain_unknown_dim = + phi::contain_unknown_dim(x_dims) || phi::contain_unknown_dim(label_dims); + bool check = config.is_runtime || !contain_unknown_dim; + if (check) { + PADDLE_ENFORCE_EQ( + x_dims[0], + label_dims[0], + phi::errors::InvalidArgument( + "ShapeError: Expected input batch_size to match label batch_size," + "But received: the Input(x) batch_size is [%s], the Input(label) " + " batch_size is [%s].", + x_dims[0], + label_dims[0])); + if (weight.get_ptr() != nullptr) { + auto w_dims = weight->dims(); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 1, + phi::errors::InvalidArgument("Input(Weight) should be a 1D tensor.")); + PADDLE_ENFORCE_EQ( + x_dims[1], + w_dims[0], + phi::errors::InvalidArgument( + "Expected input tensor Weight's size should equal " + "to the first dimension of the input tensor X. But received " + "Weight's " + "size is %d, the first dimension of input X is %d", + w_dims[0], + x_dims[1])); + } + } + if (x_dims.size() == 2) { + if (reduction == "none") { + out->set_dims({x_dims[0]}); + } else { + out->set_dims({1}); + } + } else if (x_dims.size() == 4) { + PADDLE_ENFORCE_EQ(label_dims.size(), + 3, + phi::errors::InvalidArgument( + "Expected Input(Lable) dimensions=3, received %d.", + label_dims.size())); + auto input0 = x_dims[0]; + auto input2 = x_dims[2]; + auto input3 = x_dims[3]; + auto label0 = label_dims[0]; + auto label1 = label_dims[1]; + auto label2 = label_dims[2]; + PADDLE_ENFORCE_EQ( + input0 == label0 && input2 == label1 && input3 == label2, + true, + phi::errors::InvalidArgument("Input(X) tensor shape should " + "match to Input(Label) tensor " + "shape.")); + if (reduction == "none") { + out->set_dims({x_dims[0], x_dims[2], x_dims[3]}); + } else { + out->set_dims({1}); + } + } + total_weight->set_dims({1}); + out->set_dtype(input.dtype()); + total_weight->set_dtype(input.dtype()); +} + void ScatterInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& updates, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 2ccba1b89f..c9a7e78db7 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -56,6 +56,15 @@ void ScatterInferMeta(const MetaTensor& x, bool overwrite, MetaTensor* out); +void NllLossRawInferMeta(const MetaTensor& input, + const MetaTensor& label, + paddle::optional weight, + int64_t ignore_index, + const std::string& reduction, + MetaTensor* out, + MetaTensor* total_weight, + MetaConfig config = MetaConfig()); + void ScatterNdAddInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& updates, -- GitLab