Unverified · Commit dabfbba9 authored by RedContritio, committed by GitHub

support auto generate for activation op relu6 (#53979)

* support auto generate for activation_op relu6

* add generated_static_op for activation_op in CMakeLists.txt
Parent c558dc30
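For reference, the relu6 activation being migrated to the auto-generated static op clamps its input to [0, threshold], with threshold defaulting to 6.0 (this matches the formula in the hand-written Relu6OpMaker removed below). A minimal NumPy sketch of those semantics — relu6_ref is a hypothetical helper for illustration, not a Paddle API:

```python
import numpy as np

def relu6_ref(x, threshold=6.0):
    # Reference semantics: out = min(max(0, x), threshold)
    return np.minimum(np.maximum(x, 0.0), threshold)

x = np.array([-1.0, 0.5, 3.0, 7.0], dtype=np.float32)
print(relu6_ref(x))  # [0.  0.5 3.  6. ]
```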
@@ -460,6 +460,7 @@ if(WITH_MKLDNN)
     conv_transpose_op
     batch_norm_op
     generated_op
+    generated_static_op
     activation_op
     elementwise_add_op
     concat_and_split
...
@@ -60,7 +60,8 @@ if(WITH_TESTING)
     mul_op
     activation_op
     elementwise_add_op
-    generated_op)
+    generated_op
+    generated_static_op)
   set_tests_properties(build_cinn_pass_test PROPERTIES LABELS "RUN_TYPE=CINN")
   target_link_libraries(build_cinn_pass_test ${PYTHON_LIBRARIES})
...
@@ -176,27 +176,6 @@ $$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$$
 }
 };
 
-class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X",
-             "Input of relu6 operator, an N-D Tensor, "
-             "with data type float32, float64.");
-    AddOutput(
-        "Out",
-        "Output of relu6 operator, a Tensor with the same shape as input.");
-    AddAttr<float>("threshold",
-                   "The threshold value of Relu6. Default is 6.0. ")
-        .SetDefault(6.0f);
-    AddComment(R"DOC(
-Relu6 Activation Operator.
-
-$$out = \min(\max(0, x), threshold)$$
-
-)DOC");
-  }
-};
-
 class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -452,7 +431,6 @@ FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
 REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)
 
-REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
 REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
 REGISTER_ACTIVATION_OP_WITH_COMP(hard_swish,
                                  HardSwish,
...
@@ -1902,6 +1902,10 @@
 - op : relu6
   backward : relu6_grad
+  inputs :
+    x : X
+  outputs :
+    out : Out
   extra :
     attrs : [bool use_mkldnn = false]
...
@@ -43,6 +43,17 @@
     func : frobenius_norm_grad
     param : [x, out, out_grad, axis, keepdim, reduce_all]
 
+- backward_op : relu6_grad
+  forward : relu6 (Tensor x, float threshold = 6.0f) -> Tensor(out)
+  args : (Tensor out, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [out]
+  kernel :
+    func : relu6_grad
+  inplace : (out_grad -> x_grad)
+
 - backward_op : rnn_grad
   forward : rnn (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false) -> Tensor(out), Tensor(dropout_state_out), Tensor[](state), Tensor(reserve)
   args : (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor out, Tensor dropout_state_out, Tensor reserve, Tensor out_grad, Tensor[] state_grad, float dropout_prob, bool is_bidirec, int input_size, int hidden_size, int num_layers, str mode, int seed, bool is_test)
...
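The generated relu6_grad takes (out, out_grad) rather than the forward input x: because relu6 only clamps, its gradient mask is recoverable from the output alone — the upstream gradient passes through where 0 < out < threshold and is zeroed elsewhere. A hedged NumPy sketch of that rule (relu6_grad_ref is a hypothetical helper, not the actual kernel):

```python
import numpy as np

def relu6_grad_ref(out, out_grad, threshold=6.0):
    # dx = dout where 0 < out < threshold, else 0
    mask = (out > 0.0) & (out < threshold)
    return out_grad * mask.astype(out_grad.dtype)

out = np.array([0.0, 0.5, 3.0, 6.0], dtype=np.float32)
print(relu6_grad_ref(out, np.ones_like(out)))  # [0. 1. 1. 0.]
```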
@@ -313,6 +313,16 @@
     func : reduce_scatter
     param: [x, nranks]
 
+- op : relu6
+  args : (Tensor x, float threshold = 6.0f)
+  output : Tensor
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : relu6_raw
+  backward : relu6_grad
+
 - op : rnn
   args: (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false)
   output: Tensor(out), Tensor(dropout_state_out), Tensor[](state){pre_state.size()}, Tensor(reserve)
...
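This YAML entry, together with the op_compat mapping above, is what allows the hand-written Relu6OpMaker, the REGISTER_ACTIVATION_OP(relu6, ...) call, and the argument-mapping functions removed below to be deleted: the static op and its relu6_raw kernel signature are now generated from this description. The user-facing API should be unaffected; a small usage sketch, assuming the existing paddle.nn.functional.relu6 API:

```python
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-1.0, 0.5, 3.0, 7.0])
print(F.relu6(x))  # values: [0., 0.5, 3., 6.]
```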
@@ -47,11 +47,6 @@ KernelSignature SwishGradOpArgumentMapping(
   return KernelSignature("swish_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
 }
 
-KernelSignature Relu6GradOpArgumentMapping(
-    const ArgumentMappingContext& ctx UNUSED) {
-  return KernelSignature("relu6_grad", {"Out", "Out@GRAD"}, {}, {"X@GRAD"});
-}
-
 KernelSignature HardSwishGradOpArgumentMapping(
     const ArgumentMappingContext& ctx UNUSED) {
   return KernelSignature("hardswish_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
@@ -67,11 +62,6 @@ KernelSignature SwishOpArgumentMapping(
   return KernelSignature("swish_raw", {"X"}, {"beta"}, {"Out"});
 }
 
-KernelSignature Relu6OpArgumentMapping(
-    const ArgumentMappingContext& ctx UNUSED) {
-  return KernelSignature("relu6_raw", {"X"}, {"threshold"}, {"Out"});
-}
-
 }  // namespace phi
 
 PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish);
@@ -79,8 +69,6 @@ PD_REGISTER_BASE_KERNEL_NAME(hard_swish_grad, hardswish_grad);
 PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(relu6_grad, phi::Relu6GradOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(relu6, phi::Relu6OpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(hard_swish_grad,
                            phi::HardSwishGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(hard_swish, phi::HardSwishOpArgumentMapping);
...
@@ -42,6 +42,7 @@ cc_test_old(
     crop_op
     activation_op
     generated_op
+    generated_static_op
     phi
     transpose_op
     fused_transpose_op
...
@@ -15,6 +15,7 @@ if(WITH_TESTING AND NOT WIN32)
       feed_op
       fetch_op
       generated_op
+      generated_static_op
       transfer_layout_op
       jit_layer)
   cc_test(
...
@@ -31,6 +31,7 @@ if(WITH_GPU
     elementwise_max_op
     elementwise_div_op
     generated_op
+    generated_static_op
     squared_l2_norm_op
     memcpy_h2d_op
     memcpy_d2h_op
...