Unverified commit 06736921, authored by RedContritio, committed by GitHub

support auto generate for activation_op hardswish (#53989)

Parent: e531bb02
@@ -83,22 +83,22 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
     }
   }
 };
 
-class HardSwishCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
- public:
-  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
-
- protected:
-  void Apply() override {
-    paddle::Tensor x = this->GetSingleForwardInput("X");
-    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
-    paddle::Tensor dx = this->GetSingleInputGrad("X");
-    auto* dx_ptr = this->GetOutputPtr(&dx);
-    std::string dx_name = this->GetOutputName(dx);
-    VLOG(6) << "Runing hardswish_grad composite func";
-    prim::hardswish_grad<prim::DescTensor>(x, out_grad, dx_ptr);
-    this->RecoverOutputName(dx, dx_name);
-  }
-};
+// class HardSwishCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+//  public:
+//   using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+//
+//  protected:
+//   void Apply() override {
+//     paddle::Tensor x = this->GetSingleForwardInput("X");
+//     paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+//     paddle::Tensor dx = this->GetSingleInputGrad("X");
+//     auto* dx_ptr = this->GetOutputPtr(&dx);
+//     std::string dx_name = this->GetOutputName(dx);
+//     VLOG(6) << "Runing hardswish_grad composite func";
+//     prim::hardswish_grad<prim::DescTensor>(x, out_grad, dx_ptr);
+//     this->RecoverOutputName(dx, dx_name);
+//   }
+// };
 
 phi::KernelKey GetKernelType(const framework::ExecutionContext& ctx,
                              const framework::OperatorWithKernel& oper,
@@ -217,32 +217,6 @@ Mish Activation Operator.
   }
 };
 
-class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X", "Input of HardSwish operator");
-    AddOutput("Out", "Output of HardSwish operator");
-    AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
-        .SetDefault(6.0f);
-    AddAttr<float>("scale", "The scale parameter of HardSwish operator")
-        .SetDefault(6.0f);
-    AddAttr<float>("offset", "The offset parameter of HardSwish operator")
-        .SetDefault(3.0f);
-    AddComment(R"DOC(
-HardSwish Activation Operator.
-
-The hard version of swish(https://arxiv.org/pdf/1905.02244.pdf).
-
-$$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$$
-
-The threshold and scale should be positive. The offset can be either positive or negative.
-The default parameters are set according to the above reference.
-It is recommended to use the defaults for this activation.
-)DOC");
-  }
-};
-
 template <ActBwdOpFwdDeps kDepValue>
 class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
  public:
@@ -432,10 +406,6 @@ FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
 REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)
 
 REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
-REGISTER_ACTIVATION_OP_WITH_COMP(hard_swish,
-                                 HardSwish,
-                                 HardSwishFunctor,
-                                 HardSwishGradFunctor);
 REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
 
 /* ========================== register checkpoint ===========================*/
......
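The formula quoted in the removed HardSwishOpMaker doc comment, out = x * min(max(0, x + offset), threshold) / scale, can be sanity-checked outside Paddle. A minimal NumPy sketch (illustrative only; hardswish_ref is a made-up name, not part of this commit):

import numpy as np

def hardswish_ref(x, threshold=6.0, scale=6.0, offset=3.0):
    # out = x * min(max(0, x + offset), threshold) / scale
    return x * np.minimum(np.maximum(x + offset, 0.0), threshold) / scale

x = np.array([-4.0, -1.0, 0.0, 1.0, 4.0])
print(hardswish_ref(x))  # [-0. -0.33333333  0.  0.66666667  4.]

With the defaults (threshold 6, offset 3, scale 6), the function is identically 0 for x <= -3 and identically x for x >= 3, which the sample values above confirm at the endpoints.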
@@ -1100,9 +1100,10 @@
     x : X
   outputs :
     out : Out
-  backward : hard_swish_grad
+  backward : hardswish_grad (hard_swish_grad)
   extra :
     attrs : [bool use_mkldnn = false]
+  manual_signature : [hardswish]
 
 - op : hardtanh (brelu)
   backward : hardtanh_grad (brelu_grad)
......
@@ -43,6 +43,18 @@
     func : frobenius_norm_grad
     param : [x, out, out_grad, axis, keepdim, reduce_all]
 
+- backward_op : hardswish_grad
+  forward : hardswish (Tensor x, float threshold = 6.0f, float scale = 6.0f, float offset = 3.0f) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : hardswish_grad
+    param : [x, out_grad]
+  inplace : (out_grad -> x_grad)
+
 - backward_op : relu6_grad
   forward : relu6 (Tensor x, float threshold = 6.0f) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
......
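The hardswish_grad entry added above wires the auto-generated backward op to the existing PHI kernel; the math that kernel computes is the piecewise derivative of the forward formula. A minimal NumPy sketch of that derivative (illustrative, not the actual kernel; hardswish_grad_ref is a made-up name):

import numpy as np

def hardswish_grad_ref(x, out_grad, threshold=6.0, scale=6.0, offset=3.0):
    # derivative of x * min(max(x + offset, 0), threshold) / scale:
    #   0                         where x + offset <= 0
    #   threshold / scale         where x + offset >= threshold
    #   (2 * x + offset) / scale  otherwise
    dx = np.where(x + offset <= 0.0, 0.0,
                  np.where(x + offset >= threshold,
                           threshold / scale,
                           (2.0 * x + offset) / scale))
    return out_grad * dx

The inplace : (out_grad -> x_grad) line lets the generated op reuse the out_grad buffer for x_grad, matching the hand-written registration it replaces.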
@@ -180,6 +180,17 @@
     backend : x
     force_backend : force_cpu
 
+- op : hardswish
+  args : (Tensor x, float threshold = 6.0f, float scale = 6.0f, float offset = 3.0f)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : hardswish
+    param : [x]
+  backward : hardswish_grad
+
 - op : less_equal
   args : (Tensor x, Tensor y, int axis = -1, bool force_cpu=false)
   output : Tensor(out)
......
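The matching forward entry keeps user-facing behavior unchanged: the activation is still reached through paddle.nn.functional.hardswish, now backed by the generated op. A short usage sketch (assumes a standard Paddle install):

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-4.0, -1.0, 0.0, 1.0, 4.0])
out = F.hardswish(x)
print(out.numpy())  # [-0. -0.33333334  0.  0.6666667  4.]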
@@ -47,11 +47,6 @@ KernelSignature SwishGradOpArgumentMapping(
   return KernelSignature("swish_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
 }
 
-KernelSignature HardSwishGradOpArgumentMapping(
-    const ArgumentMappingContext& ctx UNUSED) {
-  return KernelSignature("hardswish_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
-}
-
 KernelSignature HardSwishOpArgumentMapping(
     const ArgumentMappingContext& ctx UNUSED) {
   return KernelSignature("hardswish", {"X"}, {}, {"Out"});
@@ -65,12 +60,8 @@ KernelSignature SwishOpArgumentMapping(
 }  // namespace phi
 
 PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish);
-PD_REGISTER_BASE_KERNEL_NAME(hard_swish_grad, hardswish_grad);
 
 PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(hard_swish_grad,
-                           phi::HardSwishGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(hard_swish, phi::HardSwishOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(swish, phi::SwishOpArgumentMapping);
@@ -433,9 +433,9 @@ def hard_swish_composite(x):
         maxmum(x + offset, 0), threshold
     ) * x / scale
     """
-    offset = 3.0
     threshold = 6.0
     scale = 6.0
+    offset = 3.0
     full_shape = x.shape if len(x.shape) == 0 else [1]
     res = (
         minimum(
......
@@ -393,10 +393,16 @@ def hardswish(x, name=None):
         x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'hardswish'
     )
 
+    threshold = 6.0
+    scale = 6.0
+    offset = 3.0
     helper = LayerHelper('hardswish', **locals())
     out = helper.create_variable_for_type_inference(x.dtype)
     helper.append_op(
-        type='hard_swish', inputs={'X': x}, outputs={'Out': out}
+        type='hard_swish',
+        inputs={'X': x},
+        outputs={'Out': out},
+        attrs={'threshold': threshold, 'scale': scale, 'offset': offset},
     )
     return out
......
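Because the static-graph helper above now passes threshold, scale, and offset explicitly instead of relying on the C++ attribute defaults that this commit deletes, a static-graph trace of the op carries those attributes. A sketch of exercising that path (assumes Paddle's static API; the program variable names are illustrative):

import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[None, 8], dtype='float32')
    # under static graph, this appends a hard_swish op whose attrs now
    # include threshold=6.0, scale=6.0, offset=3.0
    y = paddle.nn.functional.hardswish(x)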