Unverified commit d6c8ca82, authored by RedContritio, committed by GitHub

Support auto-generation for activation_op swish (#53983)

Parent: 4f56e7c2
@@ -176,21 +176,6 @@ $$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$
 }
 };
 
-class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X", "Input of Swish operator");
-    AddOutput("Out", "Output of Swish operator");
-    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
-    AddComment(R"DOC(
-Swish Activation Operator.
-
-$$out = \\frac{x}{1 + e^{- \beta \ x}}$$
-
-)DOC");
-  }
-};
-
 class MishOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -406,7 +391,6 @@ FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
 REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)
 
 REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
-REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
 
 /* ========================== register checkpoint ===========================*/
 REGISTER_OP_VERSION(leaky_relu)
......
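Note: the removed OpMaker above documented swish as $$out = \frac{x}{1 + e^{-\beta x}}$$ with beta defaulting to 1.0. As a reference point for the YAML entries below, here is a minimal standalone sketch of that elementwise computation; the function name is illustrative and this is not Paddle's actual phi kernel:

```cpp
#include <cmath>

// Reference elementwise swish, matching the removed doc string:
// out = x / (1 + exp(-beta * x)), with beta defaulting to 1.0f
// as in SetDefault(1.0f) above.
inline float swish_reference(float x, float beta = 1.0f) {
  return x / (1.0f + std::exp(-beta * x));
}
```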
@@ -2383,6 +2383,10 @@
 - op : swish
   backward : swish_grad
+  inputs :
+    x : X
+  outputs :
+    out : Out
   extra :
     attrs : [bool use_mkldnn = false]
......
@@ -121,6 +121,17 @@
     data_type : out_grad
   no_need_buffer : x
 
+- backward_op : swish_grad
+  forward : swish (Tensor x, float beta = 1.0f) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : GeneralUnaryGradInferMeta
+    param : [x]
+  kernel :
+    func : swish_grad
+  inplace : (out_grad -> x_grad)
+
 - backward_op : tril_triu_grad
   forward : tril_triu (Tensor x, int diagonal = 0, bool lower = false) -> Tensor(out)
   args : (Tensor out_grad, int diagonal, bool lower)
......
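Note: the generated swish_grad consumes x and out_grad and produces x_grad, reusing out_grad's buffer through the inplace mapping above. For reference, differentiating out = x * sigmoid(beta * x) gives the elementwise rule sketched below; the name is illustrative and this is not phi's actual kernel:

```cpp
#include <cmath>

// Reference elementwise swish gradient. With sig = sigmoid(beta * x),
// d(out)/dx = sig * (1 + beta * x * (1 - sig)), so
// x_grad = out_grad * sig * (1 + beta * x * (1 - sig)).
inline float swish_grad_reference(float x, float out_grad, float beta = 1.0f) {
  const float sig = 1.0f / (1.0f + std::exp(-beta * x));
  return out_grad * sig * (1.0f + beta * x * (1.0f - sig));
}
```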
@@ -398,6 +398,16 @@
     param : [x, axes, starts, ends, strides]
   backward : strided_slice_grad
 
+- op : swish
+  args : (Tensor x, float beta = 1.0f)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : swish_raw
+  backward : swish_grad
+
 - op : tril_indices
   args : (int rows = 0, int cols = 0, int offset = 0, DataType dtype = DataType::INT64)
   output : Tensor(out)
......
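Note: the static op binds to the swish_raw kernel, which receives beta as an explicit attribute. Assuming phi's usual _raw naming convention (the raw variant takes the extra argument; a thin wrapper supplies the default), the layering looks roughly like this hypothetical sketch — the names here are illustrative, not phi's real signatures:

```cpp
#include <cmath>

// Hypothetical layering sketch: the raw variant takes beta explicitly;
// a wrapper fixes the default (beta = 1.0f) for the attribute-free call.
inline float swish_raw_sketch(float x, float beta) {
  return x / (1.0f + std::exp(-beta * x));
}
inline float swish_sketch(float x) { return swish_raw_sketch(x, 1.0f); }
```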
@@ -42,26 +42,14 @@ namespace phi {
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hardtanh", "t_min" comma "t_max");
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold");
 
-KernelSignature SwishGradOpArgumentMapping(
-    const ArgumentMappingContext& ctx UNUSED) {
-  return KernelSignature("swish_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
-}
-
 KernelSignature HardSwishOpArgumentMapping(
     const ArgumentMappingContext& ctx UNUSED) {
   return KernelSignature("hardswish", {"X"}, {}, {"Out"});
 }
 
-KernelSignature SwishOpArgumentMapping(
-    const ArgumentMappingContext& ctx UNUSED) {
-  return KernelSignature("swish_raw", {"X"}, {"beta"}, {"Out"});
-}
-
 }  // namespace phi
 
 PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish);
 
 PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(hard_swish, phi::HardSwishOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(swish, phi::SwishOpArgumentMapping);