diff --git a/paddle/fluid/operators/trunc_op.cc b/paddle/fluid/operators/trunc_op.cc index bd3dc002990a7cf3af738eb2d914b3fc3dd9e79a..54f4deac80a74e2e471036c2e25d08a582e29a9d 100644 --- a/paddle/fluid/operators/trunc_op.cc +++ b/paddle/fluid/operators/trunc_op.cc @@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -21,14 +23,6 @@ namespace operators { class TruncOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "trunc"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "trunc"); - auto input_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim("Out", input_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } }; class TruncOpMaker : public framework::OpProtoAndCheckerMaker { @@ -75,9 +69,13 @@ class TruncGradOpMaker : public framework::SingleGradOpMaker { } // namespace operators } // namespace paddle +DELCARE_INFER_SHAPE_FUNCTOR(trunc, TruncInferShapeFunctor, + PT_INFER_META(phi::UnchangedInferMeta)); + namespace ops = paddle::operators; REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker, ops::TruncGradOpMaker, - ops::TruncGradOpMaker); + ops::TruncGradOpMaker, + TruncInferShapeFunctor); REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp);