未验证 提交 8dcae0c5 编写于 作者: W wangchaochaohu 提交者: GitHub

register OPMaker and Infer Shape Check for fused_elementwise_add (#30259)

上级 924aac22
......@@ -287,6 +287,15 @@ class FusedElemwiseActivationGradMaker
}
};
class FusedElemwiseAddActivationMaker : public FusedElemwiseActivationMaker {};
template <typename T>
class FusedElemwiseAddActivationGradMaker
: public FusedElemwiseActivationGradMaker<T> {
public:
using FusedElemwiseActivationGradMaker<T>::FusedElemwiseActivationGradMaker;
};
class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -367,6 +376,53 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
}
};
class FusedElemwiseAddActivationOp : public FusedElemwiseActivationOp {
public:
using FusedElemwiseActivationOp::FusedElemwiseActivationOp;
void InferShape(framework::InferShapeContext *ctx) const override {
FusedElemwiseActivationOp::InferShape(ctx);
std::vector<std::string> functor_names =
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
bool elemntwise_add_detected = false;
for (auto names : functor_names) {
if (names == "elementwise_add") {
elemntwise_add_detected = true;
break;
}
}
PADDLE_ENFORCE_EQ(
elemntwise_add_detected, true,
platform::errors::InvalidArgument(
"When the FusedElemwiseAddActivationOp Is used in fused pass, the "
"elementwise_add Op must be"
"detected and used, Please check the fuse pass pattern"));
}
};
class FusedElemwiseAddActivationOpGrad : public FusedElemwiseActivationOpGrad {
public:
using FusedElemwiseActivationOpGrad::FusedElemwiseActivationOpGrad;
void InferShape(framework::InferShapeContext *ctx) const override {
FusedElemwiseActivationOpGrad::InferShape(ctx);
std::vector<std::string> functor_names =
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
bool elemntwise_add_grad_detected = false;
for (auto names : functor_names) {
if (names == "elementwise_add_grad") {
elemntwise_add_grad_detected = true;
break;
}
}
PADDLE_ENFORCE_EQ(
elemntwise_add_grad_detected, true,
platform::errors::InvalidArgument(
"When the FusedElemwiseAddActivationOpGrad Is used in fused pass, "
"the elementwise_add_grad Op must be"
"detected and used, Please check the fuse pass pattern"));
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
FusedElemwiseAddActivationNoNeddBufVarInferer, "X", "Y");
} // namespace operators
......@@ -397,13 +453,13 @@ REGISTER_OP_CPU_KERNEL(
// for memory optimization, we register the fused_elemwise_add_activation OP
REGISTER_OPERATOR(
fused_elemwise_add_activation, ops::FusedElemwiseActivationOp,
ops::FusedElemwiseActivationMaker,
ops::FusedElemwiseActivationGradMaker<paddle::framework::OpDesc>,
ops::FusedElemwiseActivationGradMaker<paddle::imperative::OpBase>);
fused_elemwise_add_activation, ops::FusedElemwiseAddActivationOp,
ops::FusedElemwiseAddActivationMaker,
ops::FusedElemwiseAddActivationGradMaker<paddle::framework::OpDesc>,
ops::FusedElemwiseAddActivationGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_elemwise_add_activation_grad,
ops::FusedElemwiseAddActivationNoNeddBufVarInferer,
ops::FusedElemwiseActivationOpGrad);
ops::FusedElemwiseAddActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
fused_elemwise_add_activation,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册