diff --git a/paddle/fluid/operators/selu_op.cc b/paddle/fluid/operators/selu_op.cc index 88ef1f3ea4aa4d8d827a810026575c20e596b4e7..0372a79b967a48c63ad7912df264b1e519485c03 100644 --- a/paddle/fluid/operators/selu_op.cc +++ b/paddle/fluid/operators/selu_op.cc @@ -16,7 +16,10 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/common_infer_shape_functions.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -28,10 +31,6 @@ class SeluOp : public framework::OperatorWithKernel { const framework::AttributeMap &attrs) : OperatorWithKernel(type, inputs, outputs, attrs) {} - void InferShape(framework::InferShapeContext *ctx) const override { - return UnaryOpUnchangedInferShape(ctx); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -121,7 +120,12 @@ class SeluGradOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(selu, SeluInferShapeFunctor, + PT_INFER_META(phi::UnchangedInferMeta)); + REGISTER_OPERATOR(selu, ops::SeluOp, ops::SeluOpMaker, ops::SeluOpInferVarType, ops::SeluGradMaker, - ops::SeluGradMaker); + ops::SeluGradMaker, + SeluInferShapeFunctor); + REGISTER_OPERATOR(selu_grad, ops::SeluGradOp);