From 95280a368b9f41e6f5ca3feff138ff82d6a56bf9 Mon Sep 17 00:00:00 2001 From: Sing_chan <51314274+betterpig@users.noreply.github.com> Date: Wed, 23 Feb 2022 19:41:33 +0800 Subject: [PATCH] move trunc_op's infere shape to phi (#39772) * move trunc_op's infere shape * modify according to risheng's comment --- paddle/fluid/operators/trunc_op.cc | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/trunc_op.cc b/paddle/fluid/operators/trunc_op.cc index bd3dc00299..54f4deac80 100644 --- a/paddle/fluid/operators/trunc_op.cc +++ b/paddle/fluid/operators/trunc_op.cc @@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -21,14 +23,6 @@ namespace operators { class TruncOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "trunc"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "trunc"); - auto input_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim("Out", input_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } }; class TruncOpMaker : public framework::OpProtoAndCheckerMaker { @@ -75,9 +69,13 @@ class TruncGradOpMaker : public framework::SingleGradOpMaker { } // namespace operators } // namespace paddle +DELCARE_INFER_SHAPE_FUNCTOR(trunc, TruncInferShapeFunctor, + PT_INFER_META(phi::UnchangedInferMeta)); + namespace ops = paddle::operators; REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker, ops::TruncGradOpMaker, - ops::TruncGradOpMaker); + ops::TruncGradOpMaker, + TruncInferShapeFunctor); REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp); -- GitLab