diff --git a/paddle/fluid/operators/cumprod_op.cc b/paddle/fluid/operators/cumprod_op.cc
index 90910bbbb2050bad85d10e0467a099c42030c084..889cdac8f6882744c7a7044861d237964e6f6ac0 100644
--- a/paddle/fluid/operators/cumprod_op.cc
+++ b/paddle/fluid/operators/cumprod_op.cc
@@ -12,8 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -21,14 +23,6 @@ namespace operators {
 class CumprodOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Cumprod");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Cumprod");
-
-    ctx->ShareDim("X", "Out");
-    ctx->ShareLoD("X", "Out");
-  }
 };
 
 class CumprodOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -82,9 +76,12 @@ class CumprodGradOp : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(cumprod, CumprodInferShapeFunctor,
+                            PD_INFER_META(phi::UnchangedInferMeta));
 REGISTER_OPERATOR(cumprod, ops::CumprodOp, ops::CumprodOpMaker,
                   ops::CumprodGradOpMaker<paddle::framework::OpDesc>,
-                  ops::CumprodGradOpMaker<paddle::imperative::OpBase>);
+                  ops::CumprodGradOpMaker<paddle::imperative::OpBase>,
+                  CumprodInferShapeFunctor);
 
 REGISTER_OPERATOR(cumprod_grad, ops::CumprodGradOp);
diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc
index dcd98054b05c314da0884e8dc6be358d3afb0483..67c1942ea0b41e480c524f9c188b2a82649ba44e 100644
--- a/paddle/fluid/operators/kldiv_loss_op.cc
+++ b/paddle/fluid/operators/kldiv_loss_op.cc
@@ -11,7 +11,9 @@
 #include <memory>
 #include <string>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/binary.h"
 
 namespace paddle {
 namespace operators {
@@ -21,44 +23,6 @@ using framework::Tensor;
 class KLDivLossOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "KLDivLoss");
-    OP_INOUT_CHECK(ctx->HasInput("Target"), "Input", "Target", "KLDivLoss");
-    OP_INOUT_CHECK(ctx->HasOutput("Loss"), "Output", "Loss", "KLDivLoss");
-
-    auto dim_x = ctx->GetInputDim("X");
-    auto dim_target = ctx->GetInputDim("Target");
-    PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
-                      platform::errors::InvalidArgument(
-                          "Input(X) rank and Input(Target) rank should be "
-                          "same, but received X rank(%d) != Target rank(%d)",
-                          dim_x.size(), dim_target.size()));
-    for (int i = 0; i < dim_x.size(); i++) {
-      if (ctx->IsRuntime() || (dim_x[i] > 0 && dim_target[i] > 0)) {
-        PADDLE_ENFORCE_EQ(
-            dim_x[i], dim_target[i],
-            platform::errors::InvalidArgument(
-                "Input(X) and Input(Target) should in same shape. but received "
-                "X dimension[%d](%d) != Target dimension[%d](%d)",
-                i, dim_x[i], i, dim_target[i]));
-      }
-    }
-
-    auto reduction = ctx->Attrs().Get<std::string>("reduction");
-
-    auto reduction_valid = "mean" == reduction || "sum" == reduction ||
-                           "batchmean" == reduction || "none" == reduction;
-    PADDLE_ENFORCE_EQ(
-        reduction_valid, true,
-        platform::errors::InvalidArgument(
-            "Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'."));
-
-    if ("none" == reduction) {
-      ctx->SetOutputDim("Loss", dim_x);
-    } else {
-      ctx->SetOutputDim("Loss", {1});
-    }
-  }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
@@ -171,8 +135,12 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(KLDivLossGradNoNeedBufferVarInferer, "X");
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(kldiv_loss, KLDivInferShapeFunctor,
+                            PD_INFER_META(phi::KLDivInferMeta));
+
 REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker,
                   ops::KLDivLossOpGradMaker<paddle::framework::OpDesc>,
-                  ops::KLDivLossOpGradMaker<paddle::imperative::OpBase>);
+                  ops::KLDivLossOpGradMaker<paddle::imperative::OpBase>,
+                  KLDivInferShapeFunctor);
 REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad,
                   ops::KLDivLossGradNoNeedBufferVarInferer);
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index d103bef2d9ed331b7b9e4d3489e11fbc2c720072..4c1d169615b1c820673a1ddc26f58ffb607e13a4 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -73,6 +73,51 @@
   out->set_dtype(DataType::BOOL);
 }
 
+void KLDivInferMeta(const MetaTensor& x,
+                    const MetaTensor& label,
+                    const std::string& reduction,
+                    MetaTensor* out,
+                    MetaConfig config) {
+  auto dim_x = x.dims();
+  auto dim_target = label.dims();
+  PADDLE_ENFORCE_EQ(dim_x.size(),
+                    dim_target.size(),
+                    phi::errors::InvalidArgument(
+                        "Input(X) rank and Input(Target) rank should be "
+                        "the same, but received X rank(%d) != Target rank(%d)",
+                        dim_x.size(),
+                        dim_target.size()));
+  for (int i = 0; i < dim_x.size(); i++) {
+    if (config.is_runtime || (dim_x[i] > 0 && dim_target[i] > 0)) {
+      PADDLE_ENFORCE_EQ(
+          dim_x[i],
+          dim_target[i],
+          phi::errors::InvalidArgument(
+              "Input(X) and Input(Target) should be in the same shape, but "
+              "received X dimension[%d](%d) != Target dimension[%d](%d)",
+              i,
+              dim_x[i],
+              i,
+              dim_target[i]));
+    }
+  }
+
+  auto reduction_valid = "mean" == reduction || "sum" == reduction ||
+                         "batchmean" == reduction || "none" == reduction;
+  PADDLE_ENFORCE_EQ(
+      reduction_valid,
+      true,
+      phi::errors::InvalidArgument(
+          "Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'."));
+
+  if ("none" == reduction) {
+    out->set_dims(dim_x);
+  } else {
+    out->set_dims({1});
+  }
+  out->set_dtype(x.dtype());
+}
+
 void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
   out->share_meta(x);
 }
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 5d93bae316238ce02c33351ea4a320b3fc79a877..40641ea48581b5eeee81085506ecb8f444761a86 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -35,6 +35,12 @@ void AllValueCompareInferMeta(const MetaTensor& x,
                               MetaTensor* out,
                               MetaConfig config = MetaConfig());
 
+void KLDivInferMeta(const MetaTensor& x,
+                    const MetaTensor& label,
+                    const std::string& reduction,
+                    MetaTensor* out,
+                    MetaConfig config = MetaConfig());
+
 void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
 
 void BCELossInferMeta(const MetaTensor& input,