未验证 提交 abacc4cb 编写于 作者: X xiongkun 提交者: GitHub

transfer selu infershape (#40137)

上级 14e98a0f
...@@ -16,7 +16,10 @@ limitations under the License. */ ...@@ -16,7 +16,10 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -28,10 +31,6 @@ class SeluOp : public framework::OperatorWithKernel { ...@@ -28,10 +31,6 @@ class SeluOp : public framework::OperatorWithKernel {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
return UnaryOpUnchangedInferShape(ctx);
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
...@@ -121,7 +120,12 @@ class SeluGradOp : public framework::OperatorWithKernel { ...@@ -121,7 +120,12 @@ class SeluGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
DELCARE_INFER_SHAPE_FUNCTOR(selu, SeluInferShapeFunctor,
PT_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(selu, ops::SeluOp, ops::SeluOpMaker, ops::SeluOpInferVarType, REGISTER_OPERATOR(selu, ops::SeluOp, ops::SeluOpMaker, ops::SeluOpInferVarType,
ops::SeluGradMaker<paddle::framework::OpDesc>, ops::SeluGradMaker<paddle::framework::OpDesc>,
ops::SeluGradMaker<paddle::imperative::OpBase>); ops::SeluGradMaker<paddle::imperative::OpBase>,
SeluInferShapeFunctor);
REGISTER_OPERATOR(selu_grad, ops::SeluGradOp); REGISTER_OPERATOR(selu_grad, ops::SeluGradOp);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册