Unverified commit 06736921, authored by RedContritio, committed by GitHub

support auto generate for activation_op hardswish (#53989)

Parent e531bb02
...@@ -83,22 +83,22 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
    }
  }
};
-class HardSwishCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
- public:
-  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
- protected:
-  void Apply() override {
-    paddle::Tensor x = this->GetSingleForwardInput("X");
-    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
-    paddle::Tensor dx = this->GetSingleInputGrad("X");
-    auto* dx_ptr = this->GetOutputPtr(&dx);
-    std::string dx_name = this->GetOutputName(dx);
-    VLOG(6) << "Runing hardswish_grad composite func";
-    prim::hardswish_grad<prim::DescTensor>(x, out_grad, dx_ptr);
-    this->RecoverOutputName(dx, dx_name);
-  }
-};
+// class HardSwishCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+//  public:
+//   using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+//  protected:
+//   void Apply() override {
+//     paddle::Tensor x = this->GetSingleForwardInput("X");
+//     paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+//     paddle::Tensor dx = this->GetSingleInputGrad("X");
+//     auto* dx_ptr = this->GetOutputPtr(&dx);
+//     std::string dx_name = this->GetOutputName(dx);
+//     VLOG(6) << "Runing hardswish_grad composite func";
+//     prim::hardswish_grad<prim::DescTensor>(x, out_grad, dx_ptr);
+//     this->RecoverOutputName(dx, dx_name);
+//   }
+// };
phi::KernelKey GetKernelType(const framework::ExecutionContext& ctx,
                             const framework::OperatorWithKernel& oper,
...@@ -217,32 +217,6 @@ Mish Activation Operator.
  }
};
-class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X", "Input of HardSwish operator");
-    AddOutput("Out", "Output of HardSwish operator");
-    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
-        .SetDefault(6.0f);
-    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
-        .SetDefault(6.0f);
-    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
-        .SetDefault(3.0f);
-    AddComment(R"DOC(
-HardSwish Activation Operator.
-The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).
-$$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$$
-The threshold and scale should be positive. The offset can be either positive or negative.
-The default parameters are set according to the above reference.
-It is recommended to use the defaults for this activation.
-)DOC");
-  }
-};
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
 public:
...@@ -432,10 +406,6 @@ FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)
REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
-REGISTER_ACTIVATION_OP_WITH_COMP(hard_swish,
-                                 HardSwish,
-                                 HardSwishFunctor,
-                                 HardSwishGradFunctor);
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
/* ========================== register checkpoint ===========================*/
......
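For reference while reviewing: the HardSwishOpMaker removed above documented the operator as out = x * min(max(0, x + offset), threshold) / scale with defaults threshold = 6, scale = 6, offset = 3. Below is a minimal NumPy sketch of that formula (an illustration only, not Paddle's kernel); the sample values are worked out by hand.

```python
import numpy as np

def hardswish_reference(x, threshold=6.0, scale=6.0, offset=3.0):
    # out = x * min(max(0, x + offset), threshold) / scale, per the removed op doc.
    return x * np.minimum(np.maximum(x + offset, 0.0), threshold) / scale

x = np.array([-4.0, -1.0, 0.0, 1.0, 5.0])
print(hardswish_reference(x))  # approx. [-0., -0.3333, 0., 0.6667, 5.]
```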
...@@ -1100,9 +1100,10 @@
    x : X
  outputs :
    out : Out
-  backward : hard_swish_grad
+  backward : hardswish_grad (hard_swish_grad)
  extra :
    attrs : [bool use_mkldnn = false]
+  manual_signature : [hardswish]

- op : hardtanh (brelu)
  backward : hardtanh_grad (brelu_grad)
......
...@@ -43,6 +43,18 @@
    func : frobenius_norm_grad
    param : [x, out, out_grad, axis, keepdim, reduce_all]

+- backward_op : hardswish_grad
+  forward : hardswish (Tensor x, float threshold = 6.0f, float scale = 6.0f, float offset = 3.0f) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : hardswish_grad
+    param : [x, out_grad]
+  inplace : (out_grad -> x_grad)
+
- backward_op : relu6_grad
  forward : relu6 (Tensor x, float threshold = 6.0f) -> Tensor(out)
  args : (Tensor out, Tensor out_grad)
......
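The new hardswish_grad entry maps the backward kernel onto (x, out_grad) only, since the attributes stay at their defaults. For intuition, here is a hedged sketch of the gradient implied by the forward formula above (not the PHI kernel code): the incoming gradient is scaled by 0 for x <= -offset, by threshold / scale for x >= threshold - offset, and by (2 * x + offset) / scale in between.

```python
import numpy as np

def hardswish_grad_reference(x, out_grad, threshold=6.0, scale=6.0, offset=3.0):
    # Derivative of x * min(max(0, x + offset), threshold) / scale, piecewise in x.
    local_grad = np.where(
        x <= -offset,
        0.0,
        np.where(x >= threshold - offset, threshold / scale, (2.0 * x + offset) / scale),
    )
    return out_grad * local_grad

x = np.array([-4.0, 0.0, 1.0, 5.0])
print(hardswish_grad_reference(x, np.ones_like(x)))  # approx. [0., 0.5, 0.8333, 1.]
```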
...@@ -180,6 +180,17 @@
    backend : x
    force_backend : force_cpu

+- op : hardswish
+  args : (Tensor x, float threshold = 6.0f, float scale = 6.0f, float offset = 3.0f)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : hardswish
+    param : [x]
+  backward : hardswish_grad
+
- op : less_equal
  args : (Tensor x, Tensor y, int axis = -1, bool force_cpu=false)
  output : Tensor(out)
......
...@@ -47,11 +47,6 @@ KernelSignature SwishGradOpArgumentMapping(
  return KernelSignature("swish_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}
-KernelSignature HardSwishGradOpArgumentMapping(
-    const ArgumentMappingContext& ctx UNUSED) {
-  return KernelSignature("hardswish_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
-}
-
KernelSignature HardSwishOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("hardswish", {"X"}, {}, {"Out"});
...@@ -65,12 +60,8 @@ KernelSignature SwishOpArgumentMapping(
}  // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish);
-PD_REGISTER_BASE_KERNEL_NAME(hard_swish_grad, hardswish_grad);
PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(hard_swish_grad,
-                           phi::HardSwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(hard_swish, phi::HardSwishOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(swish, phi::SwishOpArgumentMapping);
...@@ -433,9 +433,9 @@ def hard_swish_composite(x):
        maxmum(x + offset, 0), threshold
    ) * x / scale
    """
+    offset = 3.0
    threshold = 6.0
    scale = 6.0
-    offset = 3.0
    full_shape = x.shape if len(x.shape) == 0 else [1]
    res = (
        minimum(
......
...@@ -393,10 +393,16 @@ def hardswish(x, name=None):
        x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'hardswish'
    )
+    threshold = 6.0
+    scale = 6.0
+    offset = 3.0
    helper = LayerHelper('hardswish', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
-        type='hard_swish', inputs={'X': x}, outputs={'Out': out}
+        type='hard_swish',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'threshold': threshold, 'scale': scale, 'offset': offset},
    )
    return out
......
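After this change the static-graph branch passes threshold, scale, and offset explicitly instead of relying on the op's attribute defaults. A quick way to sanity-check the public API end to end (paddle.nn.functional.hardswish; exact printed formatting may vary by version):

```python
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-4.0, 5.0, 1.0])
out = F.hardswish(x)
print(out.numpy())  # approx. [-0., 5., 0.6667]
```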