“c387fba08ce8face590793018b9950379120c40e”上不存在“release/0.10.0/doc/tutorials/rec/ml_regression_en.html”
未验证 提交 d6c8ca82 编写于 作者: R RedContritio 提交者: GitHub

support auto generate for activation_op swish (#53983)

上级 4f56e7c2
...@@ -176,21 +176,6 @@ $$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$ ...@@ -176,21 +176,6 @@ $$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$
} }
}; };
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Swish operator");
AddOutput("Out", "Output of Swish operator");
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
AddComment(R"DOC(
Swish Activation Operator.
$$out = \\frac{x}{1 + e^{- \beta \ x}}$$
)DOC");
}
};
class MishOpMaker : public framework::OpProtoAndCheckerMaker { class MishOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
...@@ -406,7 +391,6 @@ FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP); ...@@ -406,7 +391,6 @@ FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu) REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)
REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor); REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
/* ========================== register checkpoint ===========================*/ /* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(leaky_relu) REGISTER_OP_VERSION(leaky_relu)
......
...@@ -2383,6 +2383,10 @@ ...@@ -2383,6 +2383,10 @@
- op : swish - op : swish
backward : swish_grad backward : swish_grad
inputs :
x : X
outputs :
out : Out
extra : extra :
attrs : [bool use_mkldnn = false] attrs : [bool use_mkldnn = false]
......
...@@ -121,6 +121,17 @@ ...@@ -121,6 +121,17 @@
data_type : out_grad data_type : out_grad
no_need_buffer : x no_need_buffer : x
- backward_op : swish_grad
forward : swish (Tensor x, float beta = 1.0f) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : GeneralUnaryGradInferMeta
param : [x]
kernel :
func : swish_grad
inplace : (out_grad -> x_grad)
- backward_op : tril_triu_grad - backward_op : tril_triu_grad
forward : tril_triu (Tensor x, int diagonal = 0, bool lower = false) -> Tensor(out) forward : tril_triu (Tensor x, int diagonal = 0, bool lower = false) -> Tensor(out)
args : (Tensor out_grad, int diagonal, bool lower) args : (Tensor out_grad, int diagonal, bool lower)
......
...@@ -398,6 +398,16 @@ ...@@ -398,6 +398,16 @@
param : [x, axes, starts, ends, strides] param : [x, axes, starts, ends, strides]
backward : strided_slice_grad backward : strided_slice_grad
- op : swish
args : (Tensor x, float beta = 1.0f)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : swish_raw
backward : swish_grad
- op : tril_indices - op : tril_indices
args : (int rows = 0, int cols = 0, int offset = 0, DataType dtype = DataType::INT64) args : (int rows = 0, int cols = 0, int offset = 0, DataType dtype = DataType::INT64)
output : Tensor(out) output : Tensor(out)
......
...@@ -42,26 +42,14 @@ namespace phi { ...@@ -42,26 +42,14 @@ namespace phi {
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hardtanh", "t_min" comma "t_max"); DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hardtanh", "t_min" comma "t_max");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold"); DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold");
KernelSignature SwishGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("swish_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
KernelSignature HardSwishOpArgumentMapping( KernelSignature HardSwishOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) { const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("hardswish", {"X"}, {}, {"Out"}); return KernelSignature("hardswish", {"X"}, {}, {"Out"});
} }
KernelSignature SwishOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("swish_raw", {"X"}, {"beta"}, {"Out"});
}
} // namespace phi } // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish); PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish);
PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(hard_swish, phi::HardSwishOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(hard_swish, phi::HardSwishOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(swish, phi::SwishOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册