From abacc4cb1275abd5e942db3a849fcd0d83f9f9f8 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 4 Mar 2022 11:52:22 +0800 Subject: [PATCH] transfer selu infershape (#40137) --- paddle/fluid/operators/selu_op.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/selu_op.cc b/paddle/fluid/operators/selu_op.cc index 88ef1f3ea4a..0372a79b967 100644 --- a/paddle/fluid/operators/selu_op.cc +++ b/paddle/fluid/operators/selu_op.cc @@ -16,7 +16,10 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/common_infer_shape_functions.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -28,10 +31,6 @@ class SeluOp : public framework::OperatorWithKernel { const framework::AttributeMap &attrs) : OperatorWithKernel(type, inputs, outputs, attrs) {} - void InferShape(framework::InferShapeContext *ctx) const override { - return UnaryOpUnchangedInferShape(ctx); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -121,7 +120,12 @@ class SeluGradOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(selu, SeluInferShapeFunctor, + PT_INFER_META(phi::UnchangedInferMeta)); + REGISTER_OPERATOR(selu, ops::SeluOp, ops::SeluOpMaker, ops::SeluOpInferVarType, ops::SeluGradMaker, - ops::SeluGradMaker); + ops::SeluGradMaker, + SeluInferShapeFunctor); + REGISTER_OPERATOR(selu_grad, ops::SeluGradOp); -- GitLab